Skip to content

Commit

Permalink
[Add] bad statement check (PaddlePaddle#57421)
Browse files Browse the repository at this point in the history
* [Add] bad statement check

* [Change] allow `fluid` to appear in comments

* [Fix] fix regex

* [Change] add more tests for TestResult

* [Fix] default value

* [Change] restore creation.py
  • Loading branch information
megemini authored and Frida-a committed Oct 14, 2023
1 parent 9be5864 commit d0a9217
Show file tree
Hide file tree
Showing 3 changed files with 492 additions and 88 deletions.
164 changes: 99 additions & 65 deletions tools/sampcd_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""

import collections
import functools
import multiprocessing
import os
Expand Down Expand Up @@ -219,13 +220,76 @@ def parse_directive(self, docstring):
return docstring, float(self._timeout)


class BadStatement:
    """Base class for checks that look for a disallowed statement in a docstring.

    Subclasses are expected to override `msg` and to implement `check`.
    """

    # Human-readable description of the problem; empty on the base class.
    msg: str = ''

    def check(self, docstring: str) -> bool:
        """Tell whether `docstring` contains the bad statement.

        The base implementation is abstract.

        Raises:
            NotImplementedError: always, until a subclass overrides it.
        """
        raise NotImplementedError


class Fluid(BadStatement):
    """Flag sample code lines that import the `fluid` api.

    A doctest line (one that has a `>>>` or `...` prompt) is reported when it
    contains an import statement mentioning `fluid` as a whole word, in either
    form:

    - `import xxx.fluid` / `from xxx import fluid`
    - `from xxx.fluid import yyy`

    Anything after a `#` on the line is treated as a comment, so commented-out
    imports and trailing remarks that merely mention `fluid` are accepted.
    """

    msg = 'Please do NOT use `fluid` api.'

    _pattern = re.compile(
        r"""
        (\>{3}|\.{3})
        (?P<comment>.*?)
        (
            \bfrom\b [^\#\n]* \bfluid\b [^\#\n]* \bimport\b
            |
            \bimport\b [^\#\n]* \bfluid\b
        )
        """,
        re.X,
    )

    def check(self, docstring):
        """Return True when `docstring` has an uncommented `fluid` import.

        Args:
            docstring: the raw docstring text holding the sample code.

        Returns:
            True if at least one matching import is not preceded by a `#`
            on its line, False otherwise.
        """
        # `comment` is the text between the prompt and the import statement;
        # a `#` in it means the import itself sits inside a comment.
        return any(
            '#' not in match_obj.group('comment')
            for match_obj in self._pattern.finditer(docstring)
        )


class SkipNoReason(BadStatement):
    """Flag `# doctest: +SKIP` directives that give no reason for skipping.

    The reason is whatever follows `+SKIP` on the same line; surrounding
    whitespace and parentheses do not count as a reason.
    """

    msg = 'Please add sample code skip reason.'

    _pattern = re.compile(
        r"""
        \#
        \s*
        (x?doctest:)
        \s*
        [+]SKIP
        (?P<reason>.*)
        """,
        re.X,
    )

    def check(self, docstring):
        """Return True when any skip directive in `docstring` lacks a reason."""
        for found in self._pattern.finditer(docstring):
            # Peel the wrapping off step by step: outer blanks, then opening
            # parentheses, then closing ones, then any blanks left inside.
            text = found.group('reason').strip()
            text = text.strip('(')
            text = text.strip(')')
            if not text.strip():
                return True

        return False


class Xdoctester(DocTester):
"""A Xdoctest doctester."""

directives: typing.Dict[str, typing.Tuple[typing.Type[Directive], ...]] = {
'timeout': (TimeoutDirective, TEST_TIMEOUT)
}

bad_statements: typing.Dict[
str, typing.Tuple[typing.Type[BadStatement], ...]
] = {
'fluid': (Fluid,),
'skip': (SkipNoReason,),
}

def __init__(
self,
debug=False,
Expand Down Expand Up @@ -330,8 +394,31 @@ def prepare(self, test_capacity: set):

self._test_capacity = test_capacity

def _check_bad_statements(self, docstring: str) -> typing.Set[str]:
bad_results = set()
for name, statement_cls in self.bad_statements.items():
if statement_cls[0](*statement_cls[1:]).check(docstring):
bad_results.add(name)

return bad_results

def run(self, api_name: str, docstring: str) -> typing.List[TestResult]:
"""Run the xdoctest with a docstring."""
# check bad statements
bad_results = self._check_bad_statements(docstring)
if bad_results:
for name in bad_results:
logger.warning(
"%s >>> %s", api_name, str(self.bad_statements[name][0].msg)
)

return [
TestResult(
name=api_name,
badstatement=True,
)
]

# parse global directive
docstring, directives = self._parse_directive(docstring)

Expand Down Expand Up @@ -439,11 +526,8 @@ def _execute_with_queue(self, queue, examples_to_test, examples_nocode):
queue.put(self._execute(examples_to_test, examples_nocode))

def print_summary(self, test_results, whl_error=None):
summary_success = []
summary_failed = []
summary_skiptest = []
summary_timeout = []
summary_nocodes = []
summary = collections.defaultdict(list)
is_fail = False

logger.warning("----------------Check results--------------------")
logger.warning(">>> Sample code test capacity: %s", self._test_capacity)
Expand Down Expand Up @@ -472,70 +556,20 @@ def print_summary(self, test_results, whl_error=None):

else:
for test_result in test_results:
if not test_result.nocode:
if test_result.passed:
summary_success.append(test_result.name)

if test_result.skipped:
summary_skiptest.append(test_result.name)

if test_result.failed:
summary_failed.append(test_result.name)

if test_result.timeout:
summary_timeout.append(
{
'api_name': test_result.name,
'run_time': test_result.time,
}
)
else:
summary_nocodes.append(test_result.name)

if len(summary_success):
logger.info(
">>> %d sample codes ran success in env: %s",
len(summary_success),
self._test_capacity,
)
logger.info('\n'.join(summary_success))
summary[test_result.state].append(test_result)
if test_result.state.is_fail:
is_fail = True

if len(summary_skiptest):
logger.warning(
">>> %d sample codes skipped in env: %s",
len(summary_skiptest),
self._test_capacity,
)
logger.warning('\n'.join(summary_skiptest))

if len(summary_nocodes):
logger.error(
">>> %d apis don't have sample codes or could not run test in env: %s",
len(summary_nocodes),
self._test_capacity,
)
logger.error('\n'.join(summary_nocodes))
summary = sorted(summary.items(), key=lambda x: x[0].order)

if len(summary_timeout):
logger.error(
">>> %d sample codes ran timeout or error in env: %s",
len(summary_timeout),
self._test_capacity,
)
for _result in summary_timeout:
logger.error(
f"{_result['api_name']} - more than {_result['run_time']}s"
)

if len(summary_failed):
logger.error(
">>> %d sample codes ran failed in env: %s",
len(summary_failed),
self._test_capacity,
for result_cls, result_list in summary:
logging_msg = result_cls.msg(
len(result_list), self._test_capacity
)
logger.error('\n'.join(summary_failed))
result_cls.logger(logging_msg)
result_cls.logger('\n'.join([str(r) for r in result_list]))

if summary_failed or summary_timeout or summary_nocodes:
if is_fail:
logger.warning(
">>> Mistakes found in sample codes in env: %s!",
self._test_capacity,
Expand Down
163 changes: 155 additions & 8 deletions tools/sampcd_processor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import argparse
import dataclasses
import inspect
import logging
import os
Expand Down Expand Up @@ -50,18 +49,166 @@
TEST_TIMEOUT = 10


@dataclasses.dataclass
class Result:
    """Base description of one kind of sample code test outcome.

    Concrete kinds override the class attributes below and implement `msg`.
    """

    # Key identifying this kind of result.
    name: str = ''

    # Value a `TestResult` gives the matching attribute before any keyword
    # argument overrides it.
    default: bool = False

    # Whether this kind of result counts as a failure.
    is_fail: bool = False

    # Logging callable used when reporting results of this kind.
    logger: typing.Callable = logger.info

    # Position in the printed summary (not a logging level, just for
    # convenience).
    order: int = 0

    @classmethod
    def msg(cls, count: int, env: typing.Set) -> str:
        """Build the log message for `count` apis checked in running `env`.

        Raises:
            NotImplementedError: always, until a subclass overrides it.
        """
        raise NotImplementedError


class MetaResult(type):
    """A metaclass that records `Result` subclasses.

    Every `Result` subclass created through this metaclass is stamped with an
    increasing `order`. If its class body declares a string `name`, it is also
    stored in a registry under that name.
    """

    __slots__ = ()

    # Registered result classes, keyed by the `name` declared in their body.
    __cls_map = {}

    # Running counter providing the `order` of the next result class.
    __order = 0

    def __new__(
        mcs,
        name: str,
        bases: typing.Tuple[type, ...],
        namespace: typing.Dict[str, typing.Any],
    ) -> type:
        """Create the class and register it when it is a named `Result`.

        Args:
            name: name of the class being created.
            bases: base classes of the class being created.
            namespace: class body namespace of the class being created.

        Returns:
            The newly created class.
        """
        cls = super().__new__(mcs, name, bases, namespace)
        if issubclass(cls, Result):
            # Stamp the class with the order in which it was created.
            cls.order = mcs.__order
            mcs.__order += 1

            # Only register classes that declare their own string `name`.
            # A class without one used to be stored under the key `None`,
            # which made `TestResult.__init__` fail on
            # `setattr(self, None, ...)`.
            result_name = namespace.get('name')
            if isinstance(result_name, str):
                mcs.__cls_map[result_name] = cls

        return cls

    @classmethod
    def get(mcs, name: str) -> typing.Optional[typing.Type[Result]]:
        """Return the result class registered as `name`, or None if unknown."""
        return mcs.__cls_map.get(name)

    @classmethod
    def cls_map(mcs) -> typing.Dict[str, typing.Type[Result]]:
        """Return the registry mapping result names to result classes.

        The returned dict is the registry itself, not a copy.
        """
        return mcs.__cls_map


class Passed(Result, metaclass=MetaResult):
    """Result kind for sample codes that ran successfully."""

    name = 'passed'
    is_fail = False

    @classmethod
    def msg(cls, count, env):
        """Return the summary headline for `count` passed sample codes."""
        headline = f">>> {count} sample codes ran success in env: {env}"
        return headline


class Skipped(Result, metaclass=MetaResult):
    """Result kind for sample codes that were skipped; reported as a warning."""

    name = 'skipped'
    is_fail = False
    logger = logger.warning

    @classmethod
    def msg(cls, count, env):
        """Return the summary headline for `count` skipped sample codes."""
        headline = f">>> {count} sample codes skipped in env: {env}"
        return headline


class Failed(Result, metaclass=MetaResult):
    """Result kind for sample codes that failed; reported as an error."""

    name = 'failed'
    is_fail = True
    logger = logger.error

    @classmethod
    def msg(cls, count, env):
        """Return the summary headline for `count` failed sample codes."""
        headline = f">>> {count} sample codes ran failed in env: {env}"
        return headline


class NoCode(Result, metaclass=MetaResult):
    """Result kind for apis with no runnable sample code; counts as a failure."""

    name = 'nocode'
    is_fail = True
    logger = logger.error

    @classmethod
    def msg(cls, count, env):
        """Return the summary headline for `count` apis without sample code."""
        headline = f">>> {count} apis don't have sample codes or could not run test in env: {env}"
        return headline


class Timeout(Result, metaclass=MetaResult):
    """Result kind for sample codes that timed out; counts as a failure."""

    name = 'timeout'
    is_fail = True
    logger = logger.error

    @classmethod
    def msg(cls, count, env):
        """Return the summary headline for `count` timed-out sample codes."""
        headline = f">>> {count} sample codes ran timeout or error in env: {env}"
        return headline


class BadStatement(Result, metaclass=MetaResult):
    """Result kind for sample codes with bad statements; counts as a failure."""

    name = 'badstatement'
    is_fail = True
    logger = logger.error

    @classmethod
    def msg(cls, count, env):
        """Return the summary headline for `count` results with bad statements."""
        headline = (
            f">>> {count} bad statements detected in sample codes in env: {env}"
        )
        return headline


class TestResult:
    """Outcome of testing the sample code of one api.

    Besides the descriptive fields below, an instance carries one boolean
    attribute per result kind registered in `MetaResult`, and remembers which
    kind was switched on (see `state`).
    """

    name: str = ""
    nocode: bool = False
    passed: bool = False
    skipped: bool = False
    failed: bool = False
    timeout: bool = False
    time: float = float('inf')
    test_msg: str = ""
    extra_info: str = ""

    # Exactly one result kind is expected to be True; this holds its class.
    __unique_state: Result = None

    def __init__(self, **kwargs) -> None:
        """Initialise the result from keyword arguments.

        Each keyword must be an existing attribute or a registered result
        name. If no result kind is set to a truthy value, the result falls
        back to `Failed`.

        Raises:
            KeyError: if a keyword is neither an attribute nor a result name.
        """
        registry = MetaResult.cls_map()

        # Start every registered result attribute from its declared default.
        for result_name, result_cls in registry.items():
            setattr(self, result_name, result_cls.default)

        # Apply the caller's values on top of the defaults.
        for key, val in kwargs.items():
            if not hasattr(self, key) and key not in registry:
                raise KeyError('`{}` is not a valid result type.'.format(key))

            setattr(self, key, val)

            # Only a truthy value for a result kind selects the state.
            if not (key in registry and val):
                continue

            if self.__unique_state is not None:
                logger.warning('Only one result state should be True.')

            self.__unique_state = MetaResult.get(key)

        if self.__unique_state is None:
            logger.warning('Default result will be set to FAILED!')
            setattr(self, Failed.name, True)
            self.__unique_state = Failed

    @property
    def state(self) -> Result:
        """The result class selected for this test result."""
        return self.__unique_state

    def __str__(self) -> str:
        """Return the api name followed by its running time in seconds."""
        return '{}, running time: {:.3f}s'.format(self.name, self.time)


class DocTester:
"""A DocTester can be used to test the codeblock from the API's docstring.
Expand Down
Loading

0 comments on commit d0a9217

Please sign in to comment.