diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt
index 99244ea9bd919a..79fcb780feb931 100644
--- a/paddle/fluid/operators/jit/gen/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt
@@ -32,5 +32,6 @@ USE_JITKERNEL_GEN(kSeqPool)
 USE_JITKERNEL_GEN(kHMax)
 USE_JITKERNEL_GEN(kHSum)
 USE_JITKERNEL_GEN(kEmbSeqPool)
+USE_JITKERNEL_GEN(kAdam)
 USE_JITKERNEL_GEN(kSgd)
 USE_JITKERNEL_GEN(kVBroadcast)
diff --git a/paddle/fluid/operators/jit/gen/adam.cc b/paddle/fluid/operators/jit/gen/adam.cc
new file mode 100644
index 00000000000000..7e8cb7f59eed61
--- /dev/null
+++ b/paddle/fluid/operators/jit/gen/adam.cc
@@ -0,0 +1,153 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include "paddle/fluid/operators/jit/gen/adam.h"
+
+#include <stddef.h>  // offsetof
+
+#include "paddle/fluid/operators/jit/registry.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace gen {
+
+void AdamJitCode::loadArgs() {
+  static constexpr int32_t one_as_float = 0x3f800000;
+  static constexpr int32_t mask_all_ones = 0xFFFFFFFF;
+  static constexpr int64_t mask_8_divisible = 0xFFFFFFFFFFFFFFF8;
+  static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8;
+
+  mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]);
+  mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]);
+  mov(eax, one_as_float);
+  movd(xmm_one, eax);
+
+  vbroadcastss(ymm_one, xmm_one);                 // 1
+  vbroadcastss(ymm_beta1, xmm_beta1);             // beta1
+  vbroadcastss(ymm_beta2, xmm_beta2);             // beta2
+  vbroadcastss(ymm_lr, xmm_lr);                   // -lr
+  vbroadcastss(ymm_eps, xmm_eps);                 // eps
+  vsubps(ymm_one_sub_beta1, ymm_one, ymm_beta1);  // 1 - beta1
+  vsubps(ymm_one_sub_beta2, ymm_one, ymm_beta2);  // 1 - beta2
+
+  mov(reg_numel_without_tail, reg_numel);
+  and_(reg_numel_without_tail, mask_8_divisible);  // make it 8-divisible
+
+  shl(reg_numel_without_tail, 2);  // * 4 to treat it as float offset
+  shl(reg_numel, 2);
+
+  mov(eax, mask_all_ones);
+  kmovw(k1, eax);
+
+  xor_(reg_offset, reg_offset);
+}
+
+void AdamJitCode::setTailOpmask() {
+  mov(r13, rcx);
+
+  mov(rcx, reg_numel);
+  sub(rcx, reg_offset);  // get tail numel as float size
+  shr(rcx, 2);           // as elements
+  mov(r14, 1);
+  shl(r14, cl);  // 2 ^ elements
+  dec(r14);      // 2 ^ elements - 1, so numel first bits are set to 1
+  kmovw(k1, r14d);
+
+  mov(rcx, r13);
+}
+
+void AdamJitCode::mainCode() {
+  // load grad
+  vmovups(ymm7 | k1, ptr[reg_grad_ptr + reg_offset]);
+
+  // beta1 * mom1 + (1 - beta1) * g
+  vmulps(ymm8 | k1, ymm_one_sub_beta1, ymm7);
+  vfmadd231ps(ymm8 | k1, ymm_beta1, ptr[reg_mom1_ptr + reg_offset]);
+
+  // beta2 * mom2 + (1 - beta2) * g * g
+  vmulps(ymm7 | k1, ymm7, ymm7);
+  vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7);
+  vfmadd231ps(ymm7 | k1, ymm1, ptr[reg_mom2_ptr + reg_offset]);
+
+  // store mom1 and mom2
+  vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8);
+  vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm7);
+
+  // sqrt(mom2) + eps
+  vsqrtps(ymm7 | k1, ymm7);
+  vaddps(ymm7 | k1, ymm7, ymm3);
+
+  // p + (-lr) * (mom1 / sqrt(mom2) + eps)
+  vdivps(ymm7 | k1, ymm8, ymm7);
+  vfmadd213ps(ymm7 | k1, ymm2, ptr[reg_param_ptr + reg_offset]);
+
+  // store p
+  vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7);
+}
+
+void AdamJitCode::genCode() {
+  static constexpr int64_t main_loop_elems_size =
+      8 * sizeof(float);  // 8 floats in YMM
+  static constexpr int64_t offset_increment = main_loop_elems_size;
+  preCode();
+  loadArgs();
+
+  cmp(reg_numel, main_loop_elems_size);
+  jl("process_tail");
+
+  L("main_loop");
+  {
+    mainCode();
+    add(reg_offset, offset_increment);
+    cmp(reg_numel_without_tail, reg_offset);
+    jg("main_loop");
+  }
+
+  cmp(reg_numel, reg_offset);
+  je("end");
+
+  L("process_tail");
+  {
+    setTailOpmask();
+    mainCode();
+  }
+
+  L("end");
+  postCode();
+}
+
+class AdamCreator : public JitCodeCreator<adam_attr_t> {
+ public:
+  bool CanBeUsed(const adam_attr_t& attr) const override {
+    return platform::MayIUse(platform::avx512f);
+  }
+  size_t CodeSize(const adam_attr_t& attr) const override {
+    return 96 + 32 * 8;
+  }
+  std::unique_ptr<GenBase> CreateJitCode(
+      const adam_attr_t& attr) const override {
+    return make_unique<AdamJitCode>(attr, CodeSize(attr));
+  }
+};
+
+}  // namespace gen
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
+
+namespace gen = paddle::operators::jit::gen;
+
+REGISTER_JITKERNEL_GEN(kAdam, gen::AdamCreator);
diff --git a/paddle/fluid/operators/jit/gen/adam.h b/paddle/fluid/operators/jit/gen/adam.h
new file mode 100644
index 00000000000000..86a38e97ece021
--- /dev/null
+++ b/paddle/fluid/operators/jit/gen/adam.h
@@ -0,0 +1,75 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#pragma once
+
+#include <string>
+
+#include "glog/logging.h"
+#include "paddle/fluid/operators/jit/gen/jitcode.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace gen {
+
+class AdamJitCode : public JitCode {
+ public:
+  explicit AdamJitCode(const adam_attr_t& attr, size_t code_size = 256 * 1024,
+                       void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr) {
+    this->genCode();
+  }
+
+  DECLARE_JIT_CODE(AdamJitCode);
+  void genCode() override;
+  void loadArgs();
+  void setTailOpmask();
+  void mainCode();
+
+ private:
+  reg64_t reg_numel{abi_param1};
+  reg64_t reg_grad_ptr{abi_param2};
+  reg64_t reg_mom1_ptr{abi_param3};
+  reg64_t reg_mom2_ptr{abi_param4};
+  reg64_t reg_param_ptr{abi_param5};
+  reg64_t reg_mom1_out_ptr{abi_param6};
+
+  xmm_t xmm_beta1 = xmm_t(0);
+  xmm_t xmm_beta2 = xmm_t(1);
+  xmm_t xmm_lr = xmm_t(2);
+  xmm_t xmm_eps = xmm_t(3);
+  xmm_t xmm_one_sub_beta1 = xmm_t(4);
+  xmm_t xmm_one_sub_beta2 = xmm_t(5);
+  xmm_t xmm_one = xmm_t(6);
+
+  ymm_t ymm_beta1 = ymm_t(0);
+  ymm_t ymm_beta2 = ymm_t(1);
+  ymm_t ymm_lr = ymm_t(2);
+  ymm_t ymm_eps = ymm_t(3);
+  ymm_t ymm_one_sub_beta1 = ymm_t(4);
+  ymm_t ymm_one_sub_beta2 = ymm_t(5);
+  ymm_t ymm_one = ymm_t(6);
+
+  reg64_t reg_mom2_out_ptr{r10};
+  reg64_t reg_param_out_ptr{r11};
+  reg64_t reg_numel_without_tail{r12};
+  reg64_t reg_offset{rax};
+};
+
+}  // namespace gen
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h
index 23650c8efc73b0..bd84368a573881 100644
--- a/paddle/fluid/operators/jit/gen/jitcode.h
+++ b/paddle/fluid/operators/jit/gen/jitcode.h
@@ -45,6 +45,7 @@ using reg32_t = const Xbyak::Reg32;
 using xmm_t = const Xbyak::Xmm;
 using ymm_t = const Xbyak::Ymm;
 using zmm_t = const Xbyak::Zmm;
+using opmask_t = const Xbyak::Opmask;
 using Label = Xbyak::Label;
 
 typedef enum {
diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc
index 2085aa41e3b90d..4bdb65030590fd 100644
--- a/paddle/fluid/operators/jit/helper.cc
+++ b/paddle/fluid/operators/jit/helper.cc
@@ -58,6 +58,7 @@ const char* to_string(KernelType kt) {
     ONE_CASE(kSeqPool);
     ONE_CASE(kMatMul);
     ONE_CASE(kHMax);
+    ONE_CASE(kAdam);
     ONE_CASE(kHSum);
     ONE_CASE(kStrideASum);
     ONE_CASE(kSoftmax);
diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h
index 0791bb5810526c..f217cf6e778547 100644
--- a/paddle/fluid/operators/jit/helper.h
+++ b/paddle/fluid/operators/jit/helper.h
@@ -275,6 +275,11 @@ inline std::ostream& operator<<(std::ostream& os,
   return os;
 }
 
+inline std::ostream& operator<<(std::ostream& os, const adam_attr_t& attr) {
+  os << "beta1[" << attr.beta1 << "],beta2[" << attr.beta2 << "]";
+  return os;
+}
+
 inline std::ostream& operator<<(std::ostream& os, const sgd_attr_t& attr) {
   os << "param_height[" << attr.param_height << "],param_width["
      << attr.param_width << "],grad_height[" << attr.grad_height
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index 6e0393b820f378..40ea04d3c2791d 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -24,8 +24,9 @@ namespace jit {
 typedef enum {
   kNone = 0,
   // sort by alphabet
-  kCRFDecoding = 1,
-  kEmbSeqPool = 2,
+  kAdam = 1,
+  kCRFDecoding,
+  kEmbSeqPool,
   kGRUH1,
   kGRUHtPart1,
   kGRUHtPart2,
@@ -269,6 +270,21 @@ struct SgdTuple {
                             const sgd_attr_t*);
 };
 
+typedef struct adam_attr_s {
+  float beta1, beta2;
+  adam_attr_s() = default;
+  explicit adam_attr_s(float beta1, float beta2) : beta1(beta1), beta2(beta2) {}
+} adam_attr_t;
+
+template <typename T>
+struct AdamTuple {
+  static constexpr KernelType kernel_type = kAdam;
+  typedef T data_type;
+  typedef adam_attr_t attr_type;
+  typedef void (*func_type)(T, T, T, T, int64_t, const T*, const T*, const T*,
+                            const T*, T*, T*, T*);
+};
+
 typedef struct matmul_attr_s {
   int m, n, k;
   void* packed_weight{nullptr};
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index a7b1addeb5ded7..4f652002bc7455 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -63,6 +63,11 @@ int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
   return attr.grad_width;
 }
 
+template <>
+int64_t JitCodeKey<adam_attr_t>(const adam_attr_t& attr) {
+  return static_cast<int64_t>(attr.beta1 + attr.beta2);
+}
+
 }  // namespace jit
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt
index 7133f596620410..e4e3263e01ebae 100644
--- a/paddle/fluid/operators/jit/refer/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -36,5 +36,6 @@ USE_JITKERNEL_REFER(kHMax)
 USE_JITKERNEL_REFER(kStrideASum)
 USE_JITKERNEL_REFER(kSoftmax)
 USE_JITKERNEL_REFER(kEmbSeqPool)
+USE_JITKERNEL_REFER(kAdam)
 USE_JITKERNEL_REFER(kSgd)
 USE_JITKERNEL_REFER(kVBroadcast)
diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc
index 460cb6c58076d7..8669bfe37232bf 100644
--- a/paddle/fluid/operators/jit/refer/refer.cc
+++ b/paddle/fluid/operators/jit/refer/refer.cc
@@ -55,6 +55,7 @@ REGISTER_REFER_KERNEL(HSum);
 REGISTER_REFER_KERNEL(StrideASum);
 REGISTER_REFER_KERNEL(Softmax);
 REGISTER_REFER_KERNEL(EmbSeqPool);
+REGISTER_REFER_KERNEL(Adam);
 REGISTER_REFER_KERNEL(Sgd);
 REGISTER_REFER_KERNEL(VBroadcast);
 
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index 42fb7b4f279c22..3545b35a703f8c 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -552,6 +552,19 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
   }
 }
 
+template <typename T>
+void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr,
+          const T* mom1_ptr, const T* mom2_ptr, const T* param_ptr,
+          T* mom1_out_ptr, T* mom2_out_ptr, T* param_out_ptr) {
+  for (int i = 0; i < numel; ++i) {
+    mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i];
+    mom2_out_ptr[i] =
+        beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i];
+    param_out_ptr[i] =
+        param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps));
+  }
+}
+
 #define DECLARE_REFER_KERNEL(name)                          \
   template <typename T>                                     \
   class name##Kernel : public ReferKernel<name##Tuple<T>> { \
@@ -603,6 +616,7 @@ DECLARE_REFER_KERNEL(SeqPool);
 DECLARE_REFER_KERNEL(MatMul);
 DECLARE_REFER_KERNEL(Softmax);
 DECLARE_REFER_KERNEL(EmbSeqPool);
+DECLARE_REFER_KERNEL(Adam);
 DECLARE_REFER_KERNEL(Sgd);
 DECLARE_REFER_KERNEL(VBroadcast);
 
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index ff68565637c5a9..675db4a72bda33 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -841,6 +841,72 @@ void TestKernelStrideScal() {
   }
 }
 
+template <typename KernelTuple, typename PlaceType>
+void TestKernelAdam() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
+  const T lr = 0.1;
+  const T beta1 = 0.99;
+  const T beta2 = 0.95;
+  const T beta1_pow = beta1 * beta1;
+  const T beta2_pow = beta2 * beta2;
+
+  const T epsilon = 0.000001;
+  const int64_t numel = 123;
+
+  T learning_rate = lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow));
+  T eps = epsilon * sqrt(1 - beta2_pow);
+
+  std::vector<T> param(numel);
+  std::vector<T> grad(numel);
+  std::vector<T> mom1(numel);
+  std::vector<T> mom2(numel);
+
+  std::vector<T> param_out(param.size());
+  std::vector<T> mom1_out(mom1.size());
+  std::vector<T> mom2_out(mom2.size());
+
+  RandomVec<T>(numel, param.data(), 0.5f);
+  RandomVec<T>(numel, grad.data(), 0.5f);
+  RandomVec<T>(numel, mom1.data(), 0.5f);
+  RandomVec<T>(numel, mom2.data(), 0.5f);
+
+  auto ref = jit::GetReferFunc<KernelTuple>();
+  EXPECT_TRUE(ref != nullptr);
+  jit::adam_attr_t attr(beta1, beta2);
+  ref(beta1, beta2, -learning_rate, eps, numel, grad.data(), mom1.data(),
+      mom2.data(), param.data(), mom1_out.data(), mom2_out.data(),
+      param_out.data());
+
+  auto verifier = [](
+      const typename KernelTuple::func_type tgt, T beta1, T beta2, T lr, T eps,
+      int64_t numel, const std::vector<T>& grad, const std::vector<T>& mom1,
+      const std::vector<T>& mom2, const std::vector<T>& param,
+      const std::vector<T>& ref_mom1_out, const std::vector<T>& ref_mom2_out,
+      const std::vector<T>& ref_param_out) {
+    EXPECT_TRUE(tgt != nullptr);
+    EXPECT_EQ(param.size(), static_cast<size_t>(numel));
+    EXPECT_EQ(grad.size(), static_cast<size_t>(numel));
+    EXPECT_EQ(mom1.size(), static_cast<size_t>(numel));
+    EXPECT_EQ(mom2.size(), static_cast<size_t>(numel));
+
+    std::vector<T> jit_mom1_out(ref_mom1_out.size());
+    std::vector<T> jit_mom2_out(ref_mom2_out.size());
+    std::vector<T> jit_param_out(ref_param_out.size());
+
+    tgt(beta1, beta2, -lr, eps, numel, grad.data(), mom1.data(), mom2.data(),
+        param.data(), jit_mom1_out.data(), jit_mom2_out.data(),
+        jit_param_out.data());
+
+    ExpectEQ<T>(ref_mom1_out.data(), jit_mom1_out.data(), numel);
+    ExpectEQ<T>(ref_mom2_out.data(), jit_mom2_out.data(), numel);
+    ExpectEQ<T>(ref_param_out.data(), jit_param_out.data(), numel);
+  };
+  TestAllImpls<KernelTuple, PlaceType>(
+      attr, verifier, beta1, beta2, learning_rate, eps, numel, grad, mom1, mom2,
+      param, mom1_out, mom2_out, param_out);
+}
+
 template <typename KernelTuple, typename PlaceType>
 void TestKernelSgd() {
   using T = typename KernelTuple::data_type;
@@ -980,7 +1046,7 @@ TEST(JITKernel_pool, jitcreator) {
 #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
   EXPECT_EQ(jitcreators.size(), 0UL);
 #else
-  EXPECT_EQ(jitcreators.size(), 25UL);
+  EXPECT_EQ(jitcreators.size(), 26UL);
 #endif
 }
 
@@ -1014,7 +1080,7 @@ TEST(JITKernel_pool, more) {
 
 TEST(JITKernel_pool, refer) {
   const auto& kers = jit::ReferKernelPool::Instance().AllKernels();
-  EXPECT_EQ(kers.size(), 31UL);
+  EXPECT_EQ(kers.size(), 32UL);
 }
 
 // test helper
@@ -1147,9 +1213,10 @@ TEST(JITKernel_helper, attr) {
       << jit::to_string(jit::kVExp) << jit::to_string(jit::kVIdentity)
       << jit::to_string(jit::kVMul) << jit::to_string(jit::kVRelu)
       << jit::to_string(jit::kVScal) << jit::to_string(jit::kSgd)
-      << jit::to_string(jit::kVSigmoid) << jit::to_string(jit::kVSquare)
-      << jit::to_string(jit::kVSub) << jit::to_string(jit::kVTanh);
-  EXPECT_EQ(out.str().size(), 234UL);
+      << jit::to_string(jit::kAdam) << jit::to_string(jit::kVSigmoid)
+      << jit::to_string(jit::kVSquare) << jit::to_string(jit::kVSub)
+      << jit::to_string(jit::kVTanh);
+  EXPECT_EQ(out.str().size(), 239UL);
 
   // SeqPoolTypes
   out.str("");
@@ -1296,6 +1363,19 @@ TEST(JITKernel_key, emb_seq_pool) {
   EXPECT_TRUE(key4 != key5);
 }
 
+TEST(JITKernel_key, adam) {
+  jit::adam_attr_t attr1(0.4f, 0.9f);
+  jit::adam_attr_t attr2(0.4f, 0.9f);
+  jit::adam_attr_t attr3(0.1f, 0.3f);
+
+  auto key1 = jit::JitCodeKey<jit::adam_attr_t>(attr1);
+  auto key2 = jit::JitCodeKey<jit::adam_attr_t>(attr2);
+  auto key3 = jit::JitCodeKey<jit::adam_attr_t>(attr3);
+
+  EXPECT_TRUE(key1 == key2);
+  EXPECT_TRUE(key2 != key3);
+}
+
 TEST(JITKernel_key, sgd) {
   jit::sgd_attr_t attr1(1, 2, 3, 4, 5);
   jit::sgd_attr_t attr2(1, 2, 3, 4, 5);
@@ -1316,7 +1396,7 @@ TEST(JITKernel_key, sgd) {
   EXPECT_TRUE(key4 != key5);
 }
 
-// test kernerls
+// test kernels
 #define TestKernelVMul TestKernelXYZN
 #define TestKernelVAdd TestKernelXYZN
 #define TestKernelVAddRelu TestKernelXYZN
@@ -1383,6 +1463,7 @@ TEST_CPU_KERNEL(SeqPool);
 TEST_CPU_KERNEL(EmbSeqPool);
 TEST_CPU_KERNEL(MatMul);
 TEST_CPU_KERNEL(Softmax);
+TEST_CPU_KERNEL(Adam);
 TEST_CPU_KERNEL(Sgd);
 TEST_CPU_KERNEL(VBroadcast);
 
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index bcc314cd57c017..bdeaa106282d25 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -20,9 +20,11 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -506,21 +508,58 @@ class AdamOpKernel : public framework::OpKernel<T> {
                           beta2_pow_out->numel()));
 
     if (grad_var->IsType<framework::LoDTensor>()) {
-      auto* grad = ctx.Input<LoDTensor>("Grad");
+      T beta1_p = beta1_pow->data<T>()[0];
+      T beta2_p = beta2_pow->data<T>()[0];
 
-      AdamFunctor<T, CPUAdam> functor(
-          beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
-          mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
-          mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
-          lr->data<T>(), grad->data<T>(), param->data<T>(),
-          param_out->mutable_data<T>(ctx.GetPlace()));
-      functor(param->numel());
       if (!use_global_beta_pow) {
         beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
             beta1 * beta1_pow->data<T>()[0];
         beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
             beta2 * beta2_pow->data<T>()[0];
       }
+
+      auto* grad = ctx.Input<LoDTensor>("Grad");
+
+      T* param_out_ptr = param_out->mutable_data<T>(ctx.GetPlace());
+      T* mom1_out_ptr = mom1_out->mutable_data<T>(ctx.GetPlace());
+      T* mom2_out_ptr = mom2_out->mutable_data<T>(ctx.GetPlace());
+
+      T learning_rate = lr->data<T>()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p));
+      T eps = epsilon * sqrt(1 - beta2_p);
+
+      jit::adam_attr_t attr(beta1, beta2);
+      int64_t numel = param->numel();
+
+      const T* param_ptr = param->data<T>();
+      const T* mom1_ptr = mom1->data<T>();
+      const T* mom2_ptr = mom2->data<T>();
+      const T* grad_ptr = grad->data<T>();
+
+      auto adam =
+          jit::KernelFuncs<jit::AdamTuple<T>, platform::CPUPlace>::Cache().At(
+              attr);
+
+      static constexpr int64_t chunk_size = 512;
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+      for (int64_t i = 0; i < numel / chunk_size; ++i) {
+        const int64_t offset = i * chunk_size;
+        adam(beta1, beta2, -learning_rate, eps, chunk_size, grad_ptr + offset,
+             mom1_ptr + offset, mom2_ptr + offset, param_ptr + offset,
+             mom1_out_ptr + offset, mom2_out_ptr + offset,
+             param_out_ptr + offset);
+      }
+
+      if (numel % chunk_size != 0) {
+        const int64_t offset = (numel / chunk_size) * chunk_size;
+        const int64_t tail_numel = numel % chunk_size;
+        adam(beta1, beta2, -learning_rate, eps, tail_numel, grad_ptr + offset,
+             mom1_ptr + offset, mom2_ptr + offset, param_ptr + offset,
+             mom1_out_ptr + offset, mom2_out_ptr + offset,
+             param_out_ptr + offset);
+      }
     } else if (grad_var->IsType<pten::SelectedRows>()) {
       auto* grad = ctx.Input<pten::SelectedRows>("Grad");
       if (grad->rows().size() == 0) {
diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py
index a06f0d390e517d..ecac22553cbcda 100644
--- a/python/paddle/fluid/tests/unittests/test_adam_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adam_op.py
@@ -69,15 +69,19 @@ def test_check_output(self):
 
 
 class TestAdamOp2(OpTest):
+    def set_shape(self):
+        self.shape = (102, 105)
+
     def setUp(self):
         '''Test Adam Op with supplied attributes
         '''
         self.op_type = "adam"
-        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
-        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
-        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        self.set_shape()
+        param = np.random.uniform(-1, 1, self.shape).astype("float32")
+        grad = np.random.uniform(-1, 1, self.shape).astype("float32")
+        moment1 = np.random.uniform(-1, 1, self.shape).astype("float32")
         # The second moment is positive
-        moment2 = np.random.random((102, 105)).astype("float32")
+        moment2 = np.random.random(self.shape).astype("float32")
 
         learning_rate = 0.001
         beta1 = 0.9
@@ -113,6 +117,11 @@ def test_check_output(self):
         self.check_output()
 
 
+class TestAdamOnlyTailOp(TestAdamOp2):
+    def set_shape(self):
+        self.shape = (3)
+
+
 class TestAdamOpMultipleSteps(OpTest):
     def setUp(self):
         '''Test Adam Operator with supplied attributes