Skip to content

Commit

Permalink
Revert "use flat_hash_map and small_vector in kernel factory"
Browse files Browse the repository at this point in the history
This reverts commit 2309149.
  • Loading branch information
chenwhql committed Oct 15, 2021
1 parent 6ce92e5 commit e0322d5
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions paddle/tcmpt/core/kernel_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@

#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>

#include "paddle/tcmpt/core/backend.h"
#include "paddle/tcmpt/core/dtype.h"
#include "paddle/tcmpt/core/kernel_def.h"
#include "paddle/tcmpt/core/layout.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/enforce.h"
Expand Down Expand Up @@ -210,30 +209,25 @@ class KernelArgsDef {
attribute_defs_.emplace_back(AttributeArgDef(type_index));
}

const paddle::SmallVector<TensorArgDef>& input_defs() const {
return input_defs_;
}
const std::vector<TensorArgDef>& input_defs() const { return input_defs_; }

const paddle::SmallVector<TensorArgDef>& output_defs() const {
return output_defs_;
}
const std::vector<TensorArgDef>& output_defs() const { return output_defs_; }

const paddle::SmallVector<AttributeArgDef>& attribute_defs() const {
const std::vector<AttributeArgDef>& attribute_defs() const {
return attribute_defs_;
}

paddle::SmallVector<TensorArgDef>& input_defs() { return input_defs_; }
std::vector<TensorArgDef>& input_defs() { return input_defs_; }

paddle::SmallVector<TensorArgDef>& output_defs() { return output_defs_; }
std::vector<TensorArgDef>& output_defs() { return output_defs_; }

paddle::SmallVector<AttributeArgDef>& attribute_defs() {
return attribute_defs_;
}
std::vector<AttributeArgDef>& attribute_defs() { return attribute_defs_; }

private:
paddle::SmallVector<TensorArgDef> input_defs_{{}};
paddle::SmallVector<TensorArgDef> output_defs_{{}};
paddle::SmallVector<AttributeArgDef> attribute_defs_{{}};
// TODO(chenweihang): replaced by paddle::small_vector
std::vector<TensorArgDef> input_defs_{{}};
std::vector<TensorArgDef> output_defs_{{}};
std::vector<AttributeArgDef> attribute_defs_{{}};
};

class Kernel {
Expand Down Expand Up @@ -269,10 +263,10 @@ class Kernel {
class KernelFactory {
public:
// replaced by paddle::flat_hash_map later
using KernelMap = paddle::flat_hash_map<
KernelName,
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>,
KernelName::Hash>;
using KernelMap =
std::unordered_map<KernelName,
std::unordered_map<KernelKey, Kernel, KernelKey::Hash>,
KernelName::Hash>;

static KernelFactory& Instance();

Expand Down

0 comments on commit e0322d5

Please sign in to comment.