-
Notifications
You must be signed in to change notification settings - Fork 5
issue/142:添加rms_norm算子测例 #146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
JYMiracle305
wants to merge
2
commits into
main
Choose a base branch
from
issue/142
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#include "ops.hpp" | ||
#include "utils.hpp" | ||
#include <infinirt.h> | ||
#include <iomanip> | ||
#include <iostream> | ||
|
||
namespace infiniop_test::rms_norm { | ||
struct Test::Attributes { | ||
float epsilon; | ||
std::shared_ptr<Tensor> input; | ||
std::shared_ptr<Tensor> weight; | ||
std::shared_ptr<Tensor> ans; | ||
std::shared_ptr<Tensor> result; | ||
}; | ||
|
||
std::shared_ptr<Test> Test::build( | ||
std::unordered_map<std::string, std::vector<uint8_t>> attributes, | ||
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors, | ||
double rtol, double atol) { | ||
auto test = std::shared_ptr<Test>(new Test(rtol, atol)); | ||
test->_attributes = new Attributes(); | ||
|
||
if (attributes.find("epsilon") == attributes.end() | ||
|| tensors.find("input") == tensors.end() | ||
|| tensors.find("weight") == tensors.end() | ||
|| tensors.find("ans") == tensors.end() | ||
|| tensors.find("result") == tensors.end()) { | ||
throw std::runtime_error("Invalid Test: Missing attributes or tensors"); | ||
} | ||
|
||
test->_attributes->epsilon = *reinterpret_cast<float *>(attributes["epsilon"].data()); | ||
|
||
test->_attributes->input = tensors["input"]; | ||
test->_attributes->weight = tensors["weight"]; | ||
test->_attributes->ans = tensors["ans"]; | ||
test->_attributes->result = tensors["result"]; | ||
|
||
return test; | ||
} | ||
|
||
std::shared_ptr<infiniop_test::Result> Test::run( | ||
infiniopHandle_t handle, infiniDevice_t device, int device_id, | ||
size_t warm_ups, size_t iterations) { | ||
|
||
infiniopRMSNormDescriptor_t op_desc; | ||
CHECK_OR(infiniopCreateRMSNormDescriptor(handle, &op_desc, | ||
_attributes->result->desc(), | ||
_attributes->input->desc(), | ||
_attributes->weight->desc(), | ||
_attributes->epsilon), | ||
return TEST_FAILED(OP_CREATION_FAILED, "Failed to create RMSNorm descriptor")); | ||
|
||
auto input = _attributes->input->to(device, device_id); | ||
auto weight = _attributes->weight->to(device, device_id); | ||
auto result = _attributes->result->to(device, device_id); | ||
|
||
size_t workspace_size; | ||
CHECK_OR(infiniopGetRMSNormWorkspaceSize(op_desc, &workspace_size), | ||
return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size")); | ||
void *workspace = nullptr; | ||
if (workspace_size > 0) { | ||
CHECK_OR(infinirtMalloc(&workspace, workspace_size), | ||
return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace")); | ||
} | ||
|
||
CHECK_OR(infiniopRMSNorm(op_desc, | ||
workspace, workspace_size, | ||
result->data(), | ||
input->data(), | ||
weight->data(), | ||
nullptr), | ||
return TEST_FAILED(OP_EXECUTION_FAILED, "RMSNorm execution failed")); | ||
|
||
try { | ||
allClose(result, _attributes->ans, _rtol, _atol); | ||
} catch (const std::exception &e) { | ||
return TEST_FAILED(RESULT_INCORRECT, e.what()); | ||
} | ||
|
||
double elapsed_time = 0.; | ||
|
||
elapsed_time = benchmark( | ||
[=]() { | ||
infiniopRMSNorm(op_desc, | ||
workspace, workspace_size, | ||
result->data(), | ||
input->data(), | ||
weight->data(), | ||
nullptr); | ||
}, | ||
warm_ups, iterations); | ||
|
||
if (workspace != nullptr) { | ||
infinirtFree(workspace); | ||
} | ||
|
||
return TEST_PASSED(elapsed_time); | ||
} | ||
|
||
std::vector<std::string> Test::attribute_names() { | ||
return {"epsilon"}; | ||
} | ||
|
||
std::vector<std::string> Test::tensor_names() { | ||
return {"input", "weight", "ans", "result"}; | ||
} | ||
|
||
std::string Test::toString() const { | ||
std::ostringstream oss; | ||
oss << op_name() << std::endl; | ||
oss << "- epsilon=" << _attributes->epsilon << std::endl; | ||
oss << "- input: " << _attributes->input->info() << std::endl; | ||
oss << "- weight: " << _attributes->weight->info() << std::endl; | ||
oss << "- result: " << _attributes->result->info() << std::endl; | ||
oss << std::scientific << std::setprecision(2); | ||
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl; | ||
return oss.str(); | ||
} | ||
|
||
Test::~Test() { | ||
delete _attributes; | ||
} | ||
|
||
} // namespace infiniop_test::rms_norm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
from ast import List | ||
import numpy as np | ||
import gguf | ||
from typing import Optional | ||
from numpy.lib.stride_tricks import as_strided | ||
|
||
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides | ||
|
||
def create_non_contiguous(shape, dtype, stride_scale=2): | ||
expanded_shape = (shape[0] * stride_scale,) + shape[1:] | ||
buffer = np.random.uniform(-1.0, 1.0, expanded_shape).astype(dtype) * 0.001 | ||
|
||
new_strides = (buffer.strides[0] * stride_scale,) + buffer.strides[1:] | ||
|
||
return as_strided(buffer, shape=shape, strides=new_strides) | ||
|
||
def random_tensor(shape: tuple, dtype: np.dtype) -> np.ndarray: | ||
return np.random.uniform(-1.0, 1.0, shape).astype(dtype) * 0.001 | ||
|
||
def rms_norm(input: np.ndarray, weight: np.ndarray, epsilon: float) -> np.ndarray: | ||
""" | ||
使用numpy计算rms_norm结果 | ||
Args: | ||
input: 输入张量, 维度为2, 形状为 [..., hidden_size] | ||
weight: 缩放权重, 形状为 [hidden_size] | ||
epsilon: 避免除零的小常数 | ||
Returns: | ||
输出张量, 形状与 input 相同 | ||
""" | ||
squared = input ** 2 | ||
mean = np.mean(squared, axis=-1, keepdims=True) | ||
rms = np.sqrt(mean + epsilon) | ||
|
||
normalized = input / rms | ||
return normalized * weight | ||
|
||
class RMSNormTestCase(InfiniopTestCase): | ||
def __init__( | ||
self, | ||
input_shape: tuple, | ||
weight_shape: tuple, | ||
atype: np.dtype, | ||
wtype: np.dtype, | ||
epsilon: float = 1e-5, | ||
input_non_contiguous: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 直接传入input stride吧 |
||
input_stride_scale: int = 2, | ||
): | ||
super().__init__("rms_norm") | ||
if input_non_contiguous: | ||
self.input = create_non_contiguous(input_shape, atype, input_stride_scale) | ||
else: | ||
self.input = random_tensor(input_shape, atype) | ||
self.weight = random_tensor(weight_shape, wtype) | ||
self.epsilon = epsilon | ||
self.result = np.zeros_like(self.input) | ||
self.ans = rms_norm(self.input, self.weight, self.epsilon).astype(atype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 答案用f64计算 |
||
|
||
def write_test(self, test_writer: "InfiniopTestWriter"): | ||
super().write_test(test_writer) | ||
test_writer.add_float32(test_writer.gguf_key("epsilon"), self.epsilon) | ||
test_writer.add_tensor( | ||
test_writer.gguf_key("input"), | ||
self.input, | ||
raw_dtype=np_dtype_to_ggml(self.input.dtype), | ||
) | ||
test_writer.add_tensor( | ||
test_writer.gguf_key("weight"), | ||
self.weight, | ||
raw_dtype=np_dtype_to_ggml(self.weight.dtype), | ||
) | ||
test_writer.add_tensor( | ||
test_writer.gguf_key("ans"), | ||
self.ans, | ||
raw_dtype=np_dtype_to_ggml(self.ans.dtype), | ||
) | ||
test_writer.add_tensor( | ||
test_writer.gguf_key("result"), | ||
self.result, | ||
raw_dtype=np_dtype_to_ggml(self.result.dtype), | ||
) | ||
|
||
if __name__ == "__main__": | ||
test_writer = InfiniopTestWriter("rms_norm.gguf") | ||
|
||
test_cases = [ | ||
RMSNormTestCase( | ||
input_shape=(2, 256), | ||
weight_shape=(256,), | ||
atype=np.float32, | ||
wtype=np.float32, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(4, 512), | ||
weight_shape=(512,), | ||
atype=np.float32, | ||
wtype=np.float32, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(8, 1024), | ||
weight_shape=(1024,), | ||
atype=np.float32, | ||
wtype=np.float32, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(1, 768), | ||
weight_shape=(768,), | ||
atype=np.float32, | ||
wtype=np.float32, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(8, 256), | ||
weight_shape=(256,), | ||
atype=np.float32, | ||
wtype=np.float32, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(500, 4096), | ||
weight_shape=(4096,), | ||
atype=np.float32, | ||
wtype=np.float32, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(2, 256), | ||
weight_shape=(256,), | ||
atype=np.float16, | ||
wtype=np.float16, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(4, 512), | ||
weight_shape=(512,), | ||
atype=np.float16, | ||
wtype=np.float16, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(500, 4096), | ||
weight_shape=(4096,), | ||
atype=np.float16, | ||
wtype=np.float16, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(4, 512), | ||
weight_shape=(512,), | ||
atype=np.float16, | ||
wtype=np.float32, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(500, 4096), | ||
weight_shape=(4096,), | ||
atype=np.float16, | ||
wtype=np.float32, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(4, 512), | ||
weight_shape=(512,), | ||
atype=np.float32, | ||
wtype=np.float32, | ||
input_non_contiguous=True, | ||
input_stride_scale=2, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(500, 4096), | ||
weight_shape=(4096,), | ||
atype=np.float32, | ||
wtype=np.float32, | ||
input_non_contiguous=True, | ||
input_stride_scale=2, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(4, 512), | ||
weight_shape=(512,), | ||
atype=np.float16, | ||
wtype=np.float16, | ||
input_non_contiguous=True, | ||
input_stride_scale=2, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(500, 4096), | ||
weight_shape=(4096,), | ||
atype=np.float16, | ||
wtype=np.float16, | ||
input_non_contiguous=True, | ||
input_stride_scale=2, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(4, 512), | ||
weight_shape=(512,), | ||
atype=np.float16, | ||
wtype=np.float32, | ||
input_non_contiguous=True, | ||
input_stride_scale=2, | ||
), | ||
RMSNormTestCase( | ||
input_shape=(500, 4096), | ||
weight_shape=(4096,), | ||
atype=np.float16, | ||
wtype=np.float32, | ||
input_non_contiguous=True, | ||
input_stride_scale=2, | ||
), | ||
] | ||
|
||
test_writer.add_tests(test_cases) | ||
test_writer.save() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
第一维不连续的情况怎么测?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
新增生成第一维不连续张量的参数