@@ -406,6 +406,8 @@ struct vk_device_struct {
406406 bool subgroup_ballot;
407407 bool subgroup_clustered;
408408 bool multi_add;
409+ bool shader_int64;
410+ bool buffer_device_address;
409411
410412 bool add_rms_fusion;
411413 uint32_t partials_binding_alignment;
@@ -650,6 +652,7 @@ struct vk_buffer_struct {
650652 vk::MemoryPropertyFlags memory_property_flags;
651653 void * ptr;
652654 size_t size = 0;
655+ vk::DeviceAddress bda_addr {};
653656
654657 vk_device device;
655658
@@ -982,6 +985,7 @@ struct vk_op_argsort_push_constants {
982985};
983986
984987struct vk_op_im2col_push_constants {
988+ uint64_t dst_addr;
985989 uint32_t batch_offset; uint32_t offset_delta;
986990 uint32_t IC;
987991 uint32_t IW; uint32_t IH;
@@ -995,6 +999,7 @@ struct vk_op_im2col_push_constants {
995999};
9961000
9971001struct vk_op_im2col_3d_push_constants {
1002+ uint64_t dst_addr;
9981003 uint32_t nb10;
9991004 uint32_t nb11;
10001005 uint32_t nb12;
@@ -1946,10 +1951,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
19461951 return buf;
19471952 }
19481953
1954+ vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
1955+ vk::MemoryAllocateFlags mem_flags {};
1956+ if (device->buffer_device_address) {
1957+ usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
1958+ mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
1959+ }
1960+
19491961 vk::BufferCreateInfo buffer_create_info{
19501962 vk::BufferCreateFlags(),
19511963 size,
1952- vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst ,
1964+ usage_flags ,
19531965 vk::SharingMode::eExclusive,
19541966 0,
19551967 nullptr,
@@ -1961,6 +1973,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
19611973
19621974 vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
19631975
1976+ const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
1977+
19641978 for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
19651979 const auto & req_flags = *it;
19661980
@@ -1972,7 +1986,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
19721986 buf->memory_property_flags = req_flags;
19731987
19741988 try {
1975- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
1989+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
19761990 break;
19771991 } catch (const vk::SystemError& e) {
19781992 // loop and retry
@@ -2000,6 +2014,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20002014 buf->device = device;
20012015 buf->size = size;
20022016
2017+ if (device->buffer_device_address) {
2018+ const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
2019+ buf->bda_addr = device->device.getBufferAddress(addressInfo);
2020+ }
2021+
20032022#ifdef GGML_VULKAN_MEMORY_DEBUG
20042023 device->memory_logger->log_allocation(buf, size);
20052024#endif
@@ -3447,14 +3466,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
34473466
34483467 ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
34493468
3450- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3451- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3452- if (device->float_controls_rte_fp16) {
3453- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3454- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3469+ #define IM2COL(bda) \
3470+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3471+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3472+ if (device->float_controls_rte_fp16) { \
3473+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3474+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3475+ } else { \
3476+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3477+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3478+ }
3479+ if (device->shader_int64 && device->buffer_device_address) {
3480+ IM2COL(_bda)
34553481 } else {
3456- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3457- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3482+ IM2COL()
34583483 }
34593484
34603485 ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -3933,6 +3958,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
39333958 device->vendor_id != VK_VENDOR_ID_INTEL &&
39343959 getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
39353960
3961+ device->shader_int64 = device_features2.features.shaderInt64;
3962+ device->buffer_device_address = vk12_features.bufferDeviceAddress;
3963+
39363964 if (device->subgroup_size_control) {
39373965 device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
39383966 device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -9290,7 +9318,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
92909318
92919319 const uint32_t pelements = OW * KW * KH;
92929320
9321+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9322+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9323+
9324+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9325+
92939326 ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
9327+ dst_addr,
92949328 batch_offset, offset_delta,
92959329 IC, IW, IH, OW, OH, KW, KH,
92969330 pelements,
@@ -9326,8 +9360,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
93269360 const int64_t OH = ne2;
93279361 const int64_t OW = ne1;
93289362
9363+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9364+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9365+
9366+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9367+
93299368 vk_op_im2col_3d_push_constants pc {};
93309369
9370+ pc.dst_addr = dst_addr;
93319371 pc.nb10 = nb10 / ggml_type_size(src1->type);
93329372 pc.nb11 = nb11 / ggml_type_size(src1->type);
93339373 pc.nb12 = nb12 / ggml_type_size(src1->type);
0 commit comments