[e2e] add compatible test #381

Merged · 4 commits · Jul 2, 2024
103 changes: 103 additions & 0 deletions tests/compatibilty_test/execute.py
@@ -0,0 +1,103 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import brt
from brt.utils import brt_dtype_to_torch_dtype

import torch
import numpy as np
import os
import re

from reporting import TestResult


class BRTBackend:

    def __init__(self, device, brt_file_path):
        _stream = None
        self.device = None
        if device == "CPU":
            self.session = brt.Session(device=device.upper())
            self.device = "cpu"
        else:
            raise NotImplementedError(
                f"Compatibility test for {device} is not implemented")

self.session.load(brt_file_path)
self.req = self.session.new_request_context(_stream)

def _check(self, result, golden, atol=1e-06):
return torch.allclose(result, golden, atol=atol)

    def _generate_torch_outputs(self):
        # Allocate empty torch tensors matching the session's output shapes and dtypes.
        outputs = []
for offset in self.session.get_output_arg_offsets():
outputs.append(
torch.empty(self.session.get_static_shape(offset),
dtype=brt_dtype_to_torch_dtype(
self.session.get_data_type(offset)),
device=self.device))
return outputs

    def compare(self, inputs, goldens):
        # Bind inputs and outputs, run the session, and check each output against its golden.
        outputs = self._generate_torch_outputs()
assert len(self.session.get_input_arg_offsets()) == len(inputs)
assert len(self.session.get_output_arg_offsets()) == len(outputs)
assert len(outputs) == len(goldens)
for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
assert list(self.session.get_static_shape(offset)) == list(
arg.shape)
self.req.bind_arg(offset, arg.data_ptr())
for offset, ret in zip(self.session.get_output_arg_offsets(), outputs):
assert list(self.session.get_static_shape(offset)) == list(
ret.shape)
self.req.bind_arg(offset, ret.data_ptr())
self.req.finish_io_binding()
self.req.run()
self.req.sync()
return all(self._check(o, g) for o, g in zip(outputs, goldens))


def run_and_check_mlir(target, name, inp_files, out_files, byre_file):

    _device = None
    if target == "cpu":
        _device = "CPU"
    else:
        raise NotImplementedError(
            f"Compatibility test for target {target} is not implemented")

brt_backend = BRTBackend(device=_device, brt_file_path=byre_file)

cmp_res = []
for idx, (input_file, target_file) in enumerate(zip(inp_files, out_files)):
inp = np.load(input_file, allow_pickle=True)
inp = [
torch.from_numpy(inp[f]).contiguous().to(_device.lower())
for f in inp.files
]
tgt = np.load(target_file, allow_pickle=True)
tgt = [
torch.from_numpy(tgt[f]).contiguous().to(_device.lower())
for f in tgt.files
]
if brt_backend.compare(inp, tgt):
cmp_res.append(TestResult(name + str(idx), numerical_error=None))
else:
cmp_res.append(
TestResult(
name + str(idx),
                    numerical_error=
                    f"input is {input_file}, output does not match {target_file}"))

return cmp_res
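
A minimal sketch of how `run_and_check_mlir` above could be invoked directly for a single CPU case; the case name `add2` and the file paths are hypothetical placeholders, and only the function signature and the CPU-only restriction come from the code above.

```python
# Hypothetical standalone invocation of run_and_check_mlir for one CPU case.
# Paths and the case name "add2" are placeholders; real cases are described by
# testcase.json and dispatched from main.py.
from execute import run_and_check_mlir

results = run_and_check_mlir(
    target="cpu",                      # only the CPU target is implemented in BRTBackend
    name="add2",                       # hypothetical test-case name
    inp_files=["add2/inputs.0.npz"],   # golden inputs (.npz)
    out_files=["add2/outputs.0.npz"],  # golden outputs (.npz)
    byre_file="add2/add2.rt.mlir",     # byre compilation artifact
)
for r in results:
    print(r.unique_name, "PASS" if r.numerical_error is None else r.numerical_error)
```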
86 changes: 86 additions & 0 deletions tests/compatibilty_test/main.py
@@ -0,0 +1,86 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

#!/usr/bin/python3
import argparse
import os
import sys
import json

from reporting import report_results

from execute import run_and_check_mlir
"""
Usage:
This directory implements the code for compatibilty test framework. One should pass a test dir which contains:
(1) subdirs for each tese case and json conf file named `testcase.json`
(2) byre compilation artifacts named as {model_name}/{model_name}.rt.mlir
(3) several inputs and goldens named as inputs.{num}.npz and outputs.{num}.npz
"""


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--testdir",
type=str,
default=None,
help="Directory has test cases",
)
args = parser.parse_args()
return args


def run(testdir):

    def extract_name_from_testdir(testdir):
return os.path.basename(testdir)

result = []
conf_file = os.path.join(testdir, "testcase.json")
if not os.path.exists(conf_file):
raise RuntimeError(f"test case config file {conf_file} not found")
with open(conf_file, "r", encoding='utf-8') as f:
json_conf = json.load(f)
for target, data in json_conf.items():
for byre_version, cases in data.items():
for name, files in cases.items():
input_files = files["golden_inputs"]
input_files = [os.path.join(testdir, f) for f in input_files]
golden_files = files["golden_outputs"]
golden_files = [os.path.join(testdir, f) for f in golden_files]
byre_file = files["brt_entry_file"]
byre_file = os.path.join(testdir, byre_file)
                if len(input_files) != len(golden_files):
                    raise RuntimeError(
                        f"number of inputs ({len(input_files)}) and goldens ({len(golden_files)}) not equal in {name}"
                    )
if not os.path.exists(byre_file):
raise RuntimeError(f"byre file{byre_file} not found")
result += run_and_check_mlir(target, name, input_files,
golden_files, byre_file)
return result


def main():
args = parse_args()

results = run(args.testdir)

failed = report_results(results)
sys.exit(1 if failed else 0)


if __name__ == "__main__":
main()
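
Given how `run()` walks the config (target → byre version → case name → file lists), a `testcase.json` for a single CPU case might look like the sketch below; the case name `add2` and the byre version string `1.0.0` are made-up placeholders, while the keys `golden_inputs`, `golden_outputs`, and `brt_entry_file` come from the parsing code above.

```python
# Sketch of a testcase.json that run() above would accept.
# The case name "add2" and byre version "1.0.0" are illustrative only.
import json

testcase = {
    "cpu": {                                        # target passed to run_and_check_mlir
        "1.0.0": {                                  # byre version bucket
            "add2": {                               # one test case
                "golden_inputs": ["add2/inputs.0.npz"],
                "golden_outputs": ["add2/outputs.0.npz"],
                "brt_entry_file": "add2/add2.rt.mlir",
            }
        }
    }
}

with open("testcase.json", "w", encoding="utf-8") as f:
    json.dump(testcase, f, indent=2)
```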
45 changes: 45 additions & 0 deletions tests/compatibilty_test/reporting.py
@@ -0,0 +1,45 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import NamedTuple, Optional, List


class TestResult(NamedTuple):
unique_name: str
numerical_error: Optional[str]


def report_results(results: List[TestResult]):
fail_case = []
pass_case = []
for result in results:
if result.numerical_error is not None:
fail_case.append([
result.unique_name, "numerical failed: " + result.unique_name +
"\n" + result.numerical_error
])
else:
pass_case.append(result)
pass_case.sort(key=lambda x: x.unique_name)
fail_case.sort(key=lambda x: x[0])

print(f"\n****** PASS tests - {len(pass_case)} tests")
for test in pass_case:
print(test.unique_name, " --- PASS")
for test in fail_case:
print(test[1])
print(f"\n****** FAILED tests - {len(fail_case)} tests")
for test in fail_case:
print(test[0])
return len(fail_case) > 0
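
For completeness, a small usage sketch of `report_results`: it prints the PASS/FAILED summary and returns whether any case failed, which `main.py` turns into the process exit code. The case names and error text below are illustrative only.

```python
# Minimal usage sketch for report_results; result contents are made up.
import sys

from reporting import TestResult, report_results

results = [
    TestResult(unique_name="add2_0", numerical_error=None),  # passing case
    TestResult(unique_name="add2_1",
               numerical_error="input is inputs.1.npz, output does not match outputs.1.npz"),
]
failed = report_results(results)  # True if at least one case failed
sys.exit(1 if failed else 0)
```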