Added Adam FP32 JIT assembly kernel #39158

Merged 2 commits on Feb 7, 2022
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/gen/CMakeLists.txt
@@ -32,5 +32,6 @@ USE_JITKERNEL_GEN(kSeqPool)
USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool)
USE_JITKERNEL_GEN(kAdam)
USE_JITKERNEL_GEN(kSgd)
USE_JITKERNEL_GEN(kVBroadcast)
153 changes: 153 additions & 0 deletions paddle/fluid/operators/jit/gen/adam.cc
@@ -0,0 +1,153 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */

#include "paddle/fluid/operators/jit/gen/adam.h"

#include <stddef.h> // offsetof

#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

void AdamJitCode::loadArgs() {
static constexpr int32_t one_as_float = 0x3f800000;
static constexpr int32_t mask_all_ones = 0xFFFFFFFF;
static constexpr int64_t mask_8_divisible = 0xFFFFFFFFFFFFFFF8;
static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8;

mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]); // 7th pointer arg, read from the stack (first six arrive in registers)
mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]); // 8th pointer arg, read from the stack
mov(eax, one_as_float);
movd(xmm_one, eax);

vbroadcastss(ymm_one, xmm_one); // 1
vbroadcastss(ymm_beta1, xmm_beta1); // beta1
vbroadcastss(ymm_beta2, xmm_beta2); // beta2
vbroadcastss(ymm_lr, xmm_lr); // -lr
vbroadcastss(ymm_eps, xmm_eps); // eps
vsubps(ymm_one_sub_beta1, ymm_one, ymm_beta1); // 1 - beta1
vsubps(ymm_one_sub_beta2, ymm_one, ymm_beta2); // 1 - beta2

mov(reg_numel_without_tail, reg_numel);
and_(reg_numel_without_tail, mask_8_divisible); // round down to a multiple of 8

shl(reg_numel_without_tail, 2); // element count -> byte offset (* sizeof(float))
shl(reg_numel, 2);

mov(eax, mask_all_ones);
kmovw(k1, eax);

xor_(reg_offset, reg_offset);
}

void AdamJitCode::setTailOpmask() {
mov(r13, rcx); // preserve rcx: shl below needs its shift count in cl

mov(rcx, reg_numel);
sub(rcx, reg_offset); // remaining tail size in bytes
shr(rcx, 2); // convert to element count
mov(r14, 1);
shl(r14, cl); // 2^elements
dec(r14); // 2^elements - 1: the lowest `elements` bits set to 1
kmovw(k1, r14d);

mov(rcx, r13); // restore rcx
}

void AdamJitCode::mainCode() {
// load grad
vmovups(ymm7 | k1, ptr[reg_grad_ptr + reg_offset]);

// beta1 * mom1 + (1 - beta1) * g
vmulps(ymm8 | k1, ymm_one_sub_beta1, ymm7);
vfmadd231ps(ymm8 | k1, ymm_beta1, ptr[reg_mom1_ptr + reg_offset]);

// beta2 * mom2 + (1 - beta2) * g * g
vmulps(ymm7 | k1, ymm7, ymm7);
vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7);
vfmadd231ps(ymm7 | k1, ymm1, ptr[reg_mom2_ptr + reg_offset]); // ymm1 aliases ymm_beta2

// store mom1 and mom2
vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8);
vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm7);

// sqrt(mom2) + eps
vsqrtps(ymm7 | k1, ymm7);
vaddps(ymm7 | k1, ymm7, ymm3); // ymm3 aliases ymm_eps

// p + (-lr) * (mom1 / (sqrt(mom2) + eps))
vdivps(ymm7 | k1, ymm8, ymm7);
vfmadd213ps(ymm7 | k1, ymm2, ptr[reg_param_ptr + reg_offset]); // ymm2 aliases ymm_lr

// store p
vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7);
}

void AdamJitCode::genCode() {
static constexpr int64_t main_loop_elems_size =
    8 * sizeof(float); // one YMM block: 8 floats = 32 bytes
static constexpr int64_t offset_increment = main_loop_elems_size;
preCode();
loadArgs();

cmp(reg_numel, main_loop_elems_size);
jl("process_tail");

L("main_loop");
{
mainCode();
add(reg_offset, offset_increment);
cmp(reg_numel_without_tail, reg_offset);
jg("main_loop");
}

cmp(reg_numel, reg_offset);
je("end");

L("process_tail");
{
setTailOpmask();
mainCode();
}

L("end");
postCode();
}

class AdamCreator : public JitCodeCreator<adam_attr_t> {
public:
bool CanBeUsed(const adam_attr_t& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const adam_attr_t& attr) const override {
return 96 + 32 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(
const adam_attr_t& attr) const override {
return make_unique<AdamJitCode>(attr, CodeSize(attr));
}
};

} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle

namespace gen = paddle::operators::jit::gen;

REGISTER_JITKERNEL_GEN(kAdam, gen::AdamCreator);
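
As a plain-C++ cross-check of setTailOpmask: for a tail of n floats (0 < n < 8, since multiples of 8 skip straight to "end"), the opmask it builds is simply (1 << n) - 1. The sketch below is illustrative only; tail_opmask_bits is a hypothetical helper, not part of this PR.

#include <cstdint>

// Lowest `tail` bits set, so masked AVX-512 ops touch only the tail lanes.
inline uint16_t tail_opmask_bits(int64_t numel) {
  const int64_t tail = numel % 8;  // elements left after the full 8-float blocks
  return static_cast<uint16_t>((1u << tail) - 1);
}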
75 changes: 75 additions & 0 deletions paddle/fluid/operators/jit/gen/adam.h
@@ -0,0 +1,75 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */

#pragma once

#include <string>

#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

class AdamJitCode : public JitCode {
public:
explicit AdamJitCode(const adam_attr_t& attr, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}

DECLARE_JIT_CODE(AdamJitCode);
void genCode() override;
void loadArgs();
void setTailOpmask();
void mainCode();

private:
reg64_t reg_numel{abi_param1};
reg64_t reg_grad_ptr{abi_param2};
reg64_t reg_mom1_ptr{abi_param3};
reg64_t reg_mom2_ptr{abi_param4};
reg64_t reg_param_ptr{abi_param5};
reg64_t reg_mom1_out_ptr{abi_param6};

// Assuming the SysV x86-64 ABI (what abi_param1..6 resolve to on Linux),
// the four float arguments beta1, beta2, lr, eps arrive in xmm0-xmm3;
// xmm_t(i) and ymm_t(i) alias the same physical register, so loadArgs
// broadcasts each scalar in place with vbroadcastss.
xmm_t xmm_beta1 = xmm_t(0);
xmm_t xmm_beta2 = xmm_t(1);
xmm_t xmm_lr = xmm_t(2);
xmm_t xmm_eps = xmm_t(3);
xmm_t xmm_one_sub_beta1 = xmm_t(4);
xmm_t xmm_one_sub_beta2 = xmm_t(5);
xmm_t xmm_one = xmm_t(6);

ymm_t ymm_beta1 = ymm_t(0);
ymm_t ymm_beta2 = ymm_t(1);
ymm_t ymm_lr = ymm_t(2);
ymm_t ymm_eps = ymm_t(3);
ymm_t ymm_one_sub_beta1 = ymm_t(4);
ymm_t ymm_one_sub_beta2 = ymm_t(5);
ymm_t ymm_one = ymm_t(6);

reg64_t reg_mom2_out_ptr{r10};
reg64_t reg_param_out_ptr{r11};
reg64_t reg_numel_without_tail{r12};
reg64_t reg_offset{rax};
};

} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/gen/jitcode.h
@@ -45,6 +45,7 @@ using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using opmask_t = const Xbyak::Opmask;
using Label = Xbyak::Label;

typedef enum {
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/helper.cc
@@ -58,6 +58,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kSeqPool);
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kAdam);
ONE_CASE(kHSum);
ONE_CASE(kStrideASum);
ONE_CASE(kSoftmax);
5 changes: 5 additions & 0 deletions paddle/fluid/operators/jit/helper.h
@@ -275,6 +275,11 @@ inline std::ostream& operator<<(std::ostream& os,
return os;
}

inline std::ostream& operator<<(std::ostream& os, const adam_attr_t& attr) {
os << "beta1[" << attr.beta1 << "],beta2[" << attr.beta2 << "]";
return os;
}

inline std::ostream& operator<<(std::ostream& os, const sgd_attr_t& attr) {
os << "param_height[" << attr.param_height << "],param_width["
<< attr.param_width << "],grad_height[" << attr.grad_height
20 changes: 18 additions & 2 deletions paddle/fluid/operators/jit/kernel_base.h
@@ -24,8 +24,9 @@ namespace jit {
typedef enum {
kNone = 0,
// sort by alphabet
kCRFDecoding = 1,
kEmbSeqPool = 2,
kAdam = 1,
kCRFDecoding,
kEmbSeqPool,
kGRUH1,
kGRUHtPart1,
kGRUHtPart2,
@@ -269,6 +270,21 @@ struct SgdTuple {
const sgd_attr_t*);
};

typedef struct adam_attr_s {
float beta1, beta2;
adam_attr_s() = default;
explicit adam_attr_s(float beta1, float beta2) : beta1(beta1), beta2(beta2) {}
} adam_attr_t;

template <typename T>
struct AdamTuple {
static constexpr KernelType kernel_type = kAdam;
typedef T data_type;
typedef adam_attr_t attr_type;
typedef void (*func_type)(T, T, T, T, int64_t, const T*, const T*, const T*,
const T*, T*, T*, T*);
};

typedef struct matmul_attr_s {
int m, n, k;
void* packed_weight{nullptr};
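
For context, a minimal usage sketch of the tuple above, assuming the jit::KernelFuncs<...>::Cache().At(attr) pattern that other Paddle jit kernels (e.g. Sgd) use; the variable names are placeholders and this snippet is not code from the PR. Note that the JIT broadcasts the learning rate as "-lr", so the rate is passed negated:

jit::adam_attr_t attr(beta1, beta2);
auto adam = jit::KernelFuncs<jit::AdamTuple<float>,
                             platform::CPUPlace>::Cache().At(attr);
adam(beta1, beta2, -learning_rate, epsilon, numel, grad_ptr, mom1_ptr,
     mom2_ptr, param_ptr, mom1_out_ptr, mom2_out_ptr, param_out_ptr);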
5 changes: 5 additions & 0 deletions paddle/fluid/operators/jit/kernel_key.cc
@@ -63,6 +63,11 @@ int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return attr.grad_width;
}

template <>
int64_t JitCodeKey<adam_attr_t>(const adam_attr_t& attr) {
return static_cast<int64_t>(attr.beta1 + attr.beta2);
}

} // namespace jit
} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -36,5 +36,6 @@ USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kStrideASum)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kAdam)
USE_JITKERNEL_REFER(kSgd)
USE_JITKERNEL_REFER(kVBroadcast)
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/refer/refer.cc
@@ -55,6 +55,7 @@ REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(StrideASum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Adam);
REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(VBroadcast);

14 changes: 14 additions & 0 deletions paddle/fluid/operators/jit/refer/refer.h
@@ -552,6 +552,19 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}

template <typename T>
void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr,
const T* mom1_ptr, const T* mom2_ptr, const T* param_ptr,
T* mom1_out_ptr, T* mom2_out_ptr, T* param_out_ptr) {
// Judging by the JIT kernel's "// -lr" comment, lr is expected to arrive
// pre-negated, so this computes param - learning_rate * adjusted_grad.
for (int64_t i = 0; i < numel; ++i) {
mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i];
mom2_out_ptr[i] =
beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i];
param_out_ptr[i] =
param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps));
}
}

#define DECLARE_REFER_KERNEL(name) \
template <typename T> \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \
@@ -603,6 +616,7 @@ DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Adam);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(VBroadcast);

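
A self-contained numeric sanity check of the reference update rule above (illustrative only, not part of the PR; lr is passed negated per the kernel's "-lr" convention):

#include <cmath>
#include <cstdio>

int main() {
  // One Adam step on a single element, mirroring refer::Adam.
  const float beta1 = 0.9f, beta2 = 0.999f, lr = -0.001f, eps = 1e-8f;
  const float grad = 0.5f;
  float mom1 = 0.0f, mom2 = 0.0f, param = 1.0f;
  mom1 = beta1 * mom1 + (1 - beta1) * grad;         // 0.05
  mom2 = beta2 * mom2 + (1 - beta2) * grad * grad;  // 0.00025
  param += lr * (mom1 / (std::sqrt(mom2) + eps));   // ~0.996838
  std::printf("mom1=%f mom2=%f param=%f\n", mom1, mom2, param);
  return 0;
}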