|
4 | 4 |
|
5 | 5 | #include <altivec.h> |
6 | 6 | #include <cmath> |
| 7 | +#include <algorithm> |
7 | 8 | #include <torch/all.h> |
8 | 9 |
|
9 | 10 | namespace vec_op { |
@@ -62,6 +63,10 @@ typedef struct f32x4x4_t { |
62 | 63 | __vector float val[4]; |
63 | 64 | } f32x4x4_t; |
64 | 65 |
|
// Four 128-bit vector registers holding 16 signed 32-bit integers in total;
// integer counterpart to the f32x4x4_t layout declared above.
typedef struct i32x4x4_t {
  __vector int32_t val[4];
} i32x4x4_t;
| 69 | + |
65 | 70 | struct FP32Vec8; |
66 | 71 | struct FP32Vec16; |
67 | 72 |
|
@@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec<BF16Vec16> { |
98 | 103 | vec_xst(reg.val[0], 0, (signed short*)ptr); |
99 | 104 | vec_xst(reg.val[1], 16, (signed short*)ptr); |
100 | 105 | } |
| 106 | + |
| 107 | + void save(void* ptr, const int elem_num) const { |
| 108 | + const int clamped_elem = std::max(0, std::min(elem_num, 16)); |
| 109 | + |
| 110 | + // Calculate elements to store in each 128-bit part (8 elements each) |
| 111 | + const int elements_val0 = std::min(clamped_elem, 8); |
| 112 | + const int elements_val1 = std::max(clamped_elem - 8, 0); |
| 113 | + |
| 114 | + // Convert elements to bytes (2 bytes per element) |
| 115 | + const size_t bytes_val0 = elements_val0 * sizeof(signed short); |
| 116 | + const size_t bytes_val1 = elements_val1 * sizeof(signed short); |
| 117 | + |
| 118 | + signed short* dest = static_cast<signed short*>(ptr); |
| 119 | + // Store the first part using vec_xst_len |
| 120 | + if (bytes_val0 > 0) { |
| 121 | + vec_xst_len(reg.val[0], dest, bytes_val0); |
| 122 | + } |
| 123 | + // Store the second part if needed |
| 124 | + if (bytes_val1 > 0) { |
| 125 | + vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1); |
| 126 | + } |
| 127 | + } |
101 | 128 | }; |
102 | 129 |
|
103 | 130 | const static __vector signed short zero = vec_splats((signed short)0); |
@@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec<FP32Vec8> { |
257 | 284 | } |
258 | 285 | }; |
259 | 286 |
|
// Sixteen signed 32-bit integers spread across four 128-bit registers.
struct INT32Vec16 : public Vec<INT32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  union AliasReg {
    i32x4x4_t reg;
    int32_t values[VEC_ELEM_NUM];
  };

  i32x4x4_t reg;

  // Unaligned load of 16 contiguous int32 values.
  explicit INT32Vec16(const void* data_ptr) {
    const __vector int32_t* src =
        reinterpret_cast<const __vector int32_t*>(data_ptr);
    for (int i = 0; i < 4; ++i) {
      reg.val[i] = vec_xl(i * 16, src);
    }
  }

  // Unaligned store of all 16 elements.
  void save(int32_t* ptr) const {
    __vector int32_t* dst = reinterpret_cast<__vector int32_t*>(ptr);
    for (int i = 0; i < 4; ++i) {
      vec_xst(reg.val[i], i * 16, dst);
    }
  }

  // Partial store: write only the first `elem_num` elements. Each 4-lane
  // chunk stores between 0 and 4 of its values via a length-limited store
  // (a zero-length vec_xst_len writes nothing).
  void save(int32_t* ptr, const int elem_num) const {
    for (int chunk = 0; chunk < 4; ++chunk) {
      const int lane_cnt = std::min(std::max(elem_num - chunk * 4, 0), 4);
      vec_xst_len(reg.val[chunk], ptr + chunk * 4, lane_cnt * sizeof(int32_t));
    }
  }
};
| 344 | + |
260 | 345 | struct FP32Vec16 : public Vec<FP32Vec16> { |
261 | 346 | constexpr static int VEC_ELEM_NUM = 16; |
262 | 347 | union AliasReg { |
@@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> { |
319 | 404 |
|
320 | 405 | explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} |
321 | 406 |
|
| 407 | + explicit FP32Vec16(const INT32Vec16& v) { |
| 408 | + reg.val[0] = vec_ctf(v.reg.val[0], 0); |
| 409 | + reg.val[1] = vec_ctf(v.reg.val[1], 0); |
| 410 | + reg.val[2] = vec_ctf(v.reg.val[2], 0); |
| 411 | + reg.val[3] = vec_ctf(v.reg.val[3], 0); |
| 412 | + } |
| 413 | + |
322 | 414 | FP32Vec16 operator*(const FP32Vec16& b) const { |
323 | 415 | return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), |
324 | 416 | vec_mul(reg.val[1], b.reg.val[1]), |
@@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec<FP32Vec16> { |
347 | 439 | vec_div(reg.val[3], b.reg.val[3])})); |
348 | 440 | } |
349 | 441 |
|
| 442 | + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { |
| 443 | + return FP32Vec16(f32x4x4_t( |
| 444 | + {vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])), |
| 445 | + vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])), |
| 446 | + vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])), |
| 447 | + vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))})); |
| 448 | + } |
| 449 | + |
| 450 | + FP32Vec16 max(const FP32Vec16& b) const { |
| 451 | + return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]), |
| 452 | + vec_max(reg.val[1], b.reg.val[1]), |
| 453 | + vec_max(reg.val[2], b.reg.val[2]), |
| 454 | + vec_max(reg.val[3], b.reg.val[3])})); |
| 455 | + } |
| 456 | + |
| 457 | + FP32Vec16 max(const FP32Vec16& b, int elem_num) const { |
| 458 | + FP32Vec16 result; |
| 459 | + |
| 460 | + // Create a vector of element indices for each chunk |
| 461 | + __vector unsigned int indices = {0, 1, 2, 3}; |
| 462 | + __vector unsigned int elem_num_vec = |
| 463 | + vec_splats(static_cast<unsigned int>(elem_num)); |
| 464 | + |
| 465 | + // Compute masks for each chunk |
| 466 | + __vector unsigned int chunk_offset0 = {0, 0, 0, |
| 467 | + 0}; // Chunk 0: Elements 0-3 |
| 468 | + __vector unsigned int chunk_offset1 = {4, 4, 4, |
| 469 | + 4}; // Chunk 1: Elements 4-7 |
| 470 | + __vector unsigned int chunk_offset2 = {8, 8, 8, |
| 471 | + 8}; // Chunk 2: Elements 8-11 |
| 472 | + __vector unsigned int chunk_offset3 = {12, 12, 12, |
| 473 | + 12}; // Chunk 3: Elements 12-15 |
| 474 | + |
| 475 | + // Compute masks for each chunk |
| 476 | + __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); |
| 477 | + __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); |
| 478 | + __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); |
| 479 | + __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); |
| 480 | + |
| 481 | + // Apply masks to compute the result for each chunk |
| 482 | + result.reg.val[0] = vec_sel(this->reg.val[0], |
| 483 | + vec_max(this->reg.val[0], b.reg.val[0]), mask0); |
| 484 | + result.reg.val[1] = vec_sel(this->reg.val[1], |
| 485 | + vec_max(this->reg.val[1], b.reg.val[1]), mask1); |
| 486 | + result.reg.val[2] = vec_sel(this->reg.val[2], |
| 487 | + vec_max(this->reg.val[2], b.reg.val[2]), mask2); |
| 488 | + result.reg.val[3] = vec_sel(this->reg.val[3], |
| 489 | + vec_max(this->reg.val[3], b.reg.val[3]), mask3); |
| 490 | + |
| 491 | + return FP32Vec16(result.reg); |
| 492 | + } |
| 493 | + |
| 494 | + FP32Vec16 min(const FP32Vec16& b) const { |
| 495 | + return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]), |
| 496 | + vec_min(reg.val[1], b.reg.val[1]), |
| 497 | + vec_min(reg.val[2], b.reg.val[2]), |
| 498 | + vec_min(reg.val[3], b.reg.val[3])})); |
| 499 | + } |
| 500 | + |
| 501 | + FP32Vec16 min(const FP32Vec16& b, int elem_num) const { |
| 502 | + FP32Vec16 result; |
| 503 | + |
| 504 | + vector unsigned int indices = {0, 1, 2, 3}; |
| 505 | + vector unsigned int elem_num_vec = |
| 506 | + vec_splats(static_cast<unsigned int>(elem_num)); |
| 507 | + |
| 508 | + vector unsigned int chunk_offset0 = {0, 0, 0, 0}; |
| 509 | + vector unsigned int chunk_offset1 = {4, 4, 4, 4}; |
| 510 | + vector unsigned int chunk_offset2 = {8, 8, 8, 8}; |
| 511 | + vector unsigned int chunk_offset3 = {12, 12, 12, 12}; |
| 512 | + |
| 513 | + vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); |
| 514 | + vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); |
| 515 | + vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); |
| 516 | + vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); |
| 517 | + |
| 518 | + result.reg.val[0] = vec_sel(this->reg.val[0], |
| 519 | + vec_min(this->reg.val[0], b.reg.val[0]), mask0); |
| 520 | + result.reg.val[1] = vec_sel(this->reg.val[1], |
| 521 | + vec_min(this->reg.val[1], b.reg.val[1]), mask1); |
| 522 | + result.reg.val[2] = vec_sel(this->reg.val[2], |
| 523 | + vec_min(this->reg.val[2], b.reg.val[2]), mask2); |
| 524 | + result.reg.val[3] = vec_sel(this->reg.val[3], |
| 525 | + vec_min(this->reg.val[3], b.reg.val[3]), mask3); |
| 526 | + |
| 527 | + return FP32Vec16(result.reg); |
| 528 | + } |
| 529 | + |
| 530 | + FP32Vec16 abs() const { |
| 531 | + return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]), |
| 532 | + vec_abs(reg.val[2]), vec_abs(reg.val[3])})); |
| 533 | + } |
| 534 | + |
| 535 | + float reduce_max() { |
| 536 | + __vector float max01 = vec_max(reg.val[0], reg.val[1]); |
| 537 | + __vector float max23 = vec_max(reg.val[2], reg.val[3]); |
| 538 | + __vector float max_all = vec_max(max01, max23); |
| 539 | + __vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8)); |
| 540 | + temp = vec_max(temp, vec_sld(temp, temp, 4)); |
| 541 | + return vec_extract(temp, 0); |
| 542 | + } |
| 543 | + |
| 544 | + float reduce_min() { |
| 545 | + __vector float min01 = vec_min(reg.val[0], reg.val[1]); |
| 546 | + __vector float min23 = vec_min(reg.val[2], reg.val[3]); |
| 547 | + __vector float min_all = vec_min(min01, min23); |
| 548 | + __vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8)); |
| 549 | + temp = vec_min(temp, vec_sld(temp, temp, 4)); |
| 550 | + return vec_extract(temp, 0); |
| 551 | + } |
| 552 | + |
350 | 553 | float reduce_sum() const { |
351 | 554 | AliasReg ar; |
352 | 555 | ar.reg = reg; |
@@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec<FP32Vec16> { |
377 | 580 | vec_xst(reg.val[2], 32, ptr); |
378 | 581 | vec_xst(reg.val[3], 48, ptr); |
379 | 582 | } |
| 583 | + |
| 584 | + void save(float* ptr, const int elem_num) const { |
| 585 | + const int elements_in_chunk1 = |
| 586 | + (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; |
| 587 | + const int elements_in_chunk2 = |
| 588 | + (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; |
| 589 | + const int elements_in_chunk3 = |
| 590 | + (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; |
| 591 | + const int elements_in_chunk4 = |
| 592 | + (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; |
| 593 | + |
| 594 | + const size_t bytes_chunk1 = |
| 595 | + static_cast<size_t>(elements_in_chunk1 * sizeof(float)); |
| 596 | + const size_t bytes_chunk2 = |
| 597 | + static_cast<size_t>(elements_in_chunk2 * sizeof(float)); |
| 598 | + const size_t bytes_chunk3 = |
| 599 | + static_cast<size_t>(elements_in_chunk3 * sizeof(float)); |
| 600 | + const size_t bytes_chunk4 = |
| 601 | + static_cast<size_t>(elements_in_chunk4 * sizeof(float)); |
| 602 | + |
| 603 | + vec_xst_len(reg.val[0], ptr, bytes_chunk1); |
| 604 | + vec_xst_len(reg.val[1], |
| 605 | + reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16), |
| 606 | + bytes_chunk2); |
| 607 | + vec_xst_len(reg.val[2], |
| 608 | + reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32), |
| 609 | + bytes_chunk3); |
| 610 | + vec_xst_len(reg.val[3], |
| 611 | + reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48), |
| 612 | + bytes_chunk4); |
| 613 | + } |
| 614 | +}; |
| 615 | + |
// Sixteen signed 8-bit integers in one 128-bit register.
struct INT8Vec16 : public Vec<INT8Vec16> {
  constexpr static int VEC_NUM_ELEM = 16;  // 128 bits / 8 bits = 16

  union AliasReg {
    __vector signed char reg;
    int8_t values[VEC_NUM_ELEM];
  };

  __vector signed char reg;

  // Narrowing conversion fp32 -> int8: vec_cts converts each float chunk to
  // int32, then the two vec_packs stages pack 32 -> 16 -> 8 bits with
  // signed saturation.
  explicit INT8Vec16(const FP32Vec16& vec) {
    __vector signed int ret[4];
    ret[0] = vec_cts(vec.reg.val[0], 0);
    ret[1] = vec_cts(vec.reg.val[1], 0);
    ret[2] = vec_cts(vec.reg.val[2], 0);
    ret[3] = vec_cts(vec.reg.val[3], 0);

    __vector signed short packed1 = vec_packs(ret[0], ret[1]);
    __vector signed short packed2 = vec_packs(ret[2], ret[3]);

    reg = vec_packs(packed1, packed2);
  }

  // Full store of all 16 bytes. Use vec_xst rather than dereferencing the
  // pointer as a vector: `ptr` is not guaranteed to be 16-byte aligned, and
  // a direct vector assignment assumes alignment (UB on underaligned
  // destinations). This also matches every other save() in this file.
  void save(void* ptr) const {
    vec_xst(reg, 0, static_cast<signed char*>(ptr));
  }

  // Partial store of the first `elem_num` bytes. Clamp to [0, 16] so a
  // negative count cannot wrap to a huge size_t (stxvl would then store a
  // full 16 bytes past a shorter buffer). Const-qualified like the other
  // partial save() methods.
  void save(signed char* ptr, const int elem_num) const {
    const int count = std::min(std::max(elem_num, 0), VEC_NUM_ELEM);
    vec_xst_len(reg, ptr, static_cast<size_t>(count));
  }
};
381 | 646 |
|
382 | 647 | template <typename T> |
|
0 commit comments