Skip to content

Commit

Permalink
feat(fuzzer): implement fuzzing of linear algebra operators
Browse files Browse the repository at this point in the history
  • Loading branch information
rayanht committed Apr 17, 2022
1 parent 6846733 commit 904d0dd
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 12 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -388,13 +388,13 @@ void main () {
| OpSMod | :white_check_mark: |
| OpFRem | :white_check_mark: |
| OpFMod | :white_check_mark: |
| OpVectorTimesScalar | :red_circle: |
| OpMatrixTimesScalar | :red_circle: |
| OpVectorTimesMatrix | :red_circle: |
| OpMatrixTimesVector | :red_circle: |
| OpMatrixTimesMatrix | :red_circle: |
| OpOuterProduct | :red_circle: |
| OpDot | :red_circle: |
| OpVectorTimesScalar | :white_check_mark: |
| OpMatrixTimesScalar | :white_check_mark: |
| OpVectorTimesMatrix | :white_check_mark: |
| OpMatrixTimesVector | :white_check_mark: |
| OpMatrixTimesMatrix | :white_check_mark: |
| OpOuterProduct | :white_check_mark: |
| OpDot | :white_check_mark: |
| OpIAddCarry | :red_circle: |
| OpISubBorrow | :red_circle: |
| OpUMulExtended | :red_circle: |
Expand Down
2 changes: 1 addition & 1 deletion src/fuzzing_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def gen_shader(self) -> SPIRVShader:
name="main",
interfaces=interfaces,
)
capabilities = [OpCapability(capability=Capability.Shader)]
capabilities = [OpCapability(capability=Capability.Shader), OpCapability(capability=Capability.Matrix)]
memory_model = OpMemoryModel(
addressing_model=AddressingModel.Logical, memory_model=MemoryModel.GLSL450
)
Expand Down
152 changes: 149 additions & 3 deletions src/operators/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
from typing import (
TYPE_CHECKING,
Callable,
Generic,
Optional,
TypeVar,
)
from src import Signed, Statement, Unsigned
from src import OpCode, Signed, Statement, Unsigned

from src.constants import Constant

if TYPE_CHECKING:
from src.context import Context
from src.operators import BinaryOperatorFuzzMixin, UnaryOperatorFuzzMixin
from src.predicates import (
HasValidBaseType,
HasValidBaseTypeAndSign,
HasValidType,
HaveSameTypeLength,
IsMatrixType,
IsScalarFloat,
IsValidArithmeticOperand,
IsVectorType,
)

from src.types.concrete_types import (
OpTypeFloat,
OpTypeInt,
OpTypeMatrix,
OpTypeVector,
Type,
)

Expand Down Expand Up @@ -148,5 +160,139 @@ class OpFMod(
...


# class OpVectorTimesScalar(BinaryOperatorFuzzMixin, BinaryArithmeticOperator[OpTypeVector, None, None, None]):
# ...
# The following operators override the fuzzing logic rather than relying on the mixin
# This is because trying to encompass their logic in the mixin would be way too complex:

class OpVectorTimesScalar(BinaryArithmeticOperator[None, None, None, None]):
type: OpTypeVector = None
operand1: Operand = None
operand2: Operand = None

def fuzz(self, context: "Context") -> list[OpCode]:
operand1 = context.get_random_operand(
lambda x: IsVectorType(x) and HasValidBaseType(x, OpTypeFloat)
)
if not operand1:
return []
operand2 = context.get_random_operand(IsScalarFloat)
if not operand2:
return []
self.type = operand1.type
self.operand1 = operand1
self.operand2 = operand2
return [self]

class OpMatrixTimesScalar(BinaryArithmeticOperator[None, None, None, None]):
type: OpTypeMatrix = None
operand1: Operand = None
operand2: Operand = None

def fuzz(self, context: "Context") -> list[OpCode]:
operand1 = context.get_random_operand(
lambda x: IsMatrixType(x) and HasValidBaseType(x, OpTypeFloat)
)
if not operand1:
return []
operand2 = context.get_random_operand(IsScalarFloat)
if not operand2:
return []
self.type = operand1.type
self.operand1 = operand1
self.operand2 = operand2
return [self]

class OpVectorTimesMatrix(BinaryArithmeticOperator[None, None, None, None]):
type: OpTypeVector = None
operand1: Operand = None
operand2: Operand = None

def fuzz(self, context: "Context") -> list[OpCode]:
operand1 = context.get_random_operand(
lambda x: IsVectorType(x) and HasValidBaseType(x, OpTypeFloat)
)
if not operand1:
return []
operand2 = context.get_random_operand(lambda x: IsMatrixType(x) and HasValidBaseType(x, OpTypeFloat) and HaveSameTypeLength(x, operand1))
if not operand2:
return []
self.type = OpTypeVector()
self.type.type = operand1.get_base_type()
self.type.size = len(operand2.type.type)
context.add_to_tvc(self.type)
self.operand1 = operand1
self.operand2 = operand2
return [self]

class OpMatrixTimesVector(BinaryArithmeticOperator[None, None, None, None]):
type: OpTypeVector = None
operand1: Operand = None
operand2: Operand = None

def fuzz(self, context: "Context") -> list[OpCode]:
operand1 = context.get_random_operand(lambda x: IsMatrixType(x) and HasValidBaseType(x, OpTypeFloat))
if not operand1:
return []
operand2 = context.get_random_operand(
lambda x: IsVectorType(x) and HasValidBaseType(x, OpTypeFloat) and HaveSameTypeLength(x, operand1)
)
if not operand2:
return []
self.type = operand1.type.type
self.operand1 = operand1
self.operand2 = operand2
return [self]

class OpMatrixTimesMatrix(BinaryArithmeticOperator[None, None, None, None]):
type: OpTypeMatrix = None
operand1: Operand = None
operand2: Operand = None

def fuzz(self, context: "Context") -> list[OpCode]:
operand1 = context.get_random_operand(lambda x: IsMatrixType(x) and HasValidBaseType(x, OpTypeFloat))
if not operand1:
return []
# Same number of columns that operand1 has rows
operand2 = context.get_random_operand(lambda x: IsMatrixType(x) and HasValidBaseType(x, OpTypeFloat) and HaveSameTypeLength(x, operand1.type))
if not operand2:
return []
self.type = OpTypeMatrix()
self.type.type = operand1.type.type
self.type.size = len(operand2.type)
self.operand1 = operand1
self.operand2 = operand2
return [self]

class OpOuterProduct(BinaryArithmeticOperator[None, None, None, None]):
type: OpTypeMatrix = None
operand1: Operand = None
operand2: Operand = None

def fuzz(self, context: "Context") -> list[OpCode]:
predicate = lambda x: IsVectorType(x) and HasValidBaseType(x, OpTypeFloat)
operand1 = context.get_random_operand(predicate)
if not operand1:
return []
operand2 = context.get_random_operand(predicate)
self.type = OpTypeMatrix()
self.type.type = operand1.type
self.type.size = len(operand2.type)
context.add_to_tvc(self.type)
self.operand1 = operand1
self.operand2 = operand2
return [self]

class OpDot(BinaryArithmeticOperator[None, None, None, None]):
type: OpTypeFloat = None
operand1: Operand = None
operand2: Operand = None

def fuzz(self, context: "Context") -> list[OpCode]:
predicate = lambda x: IsVectorType(x) and HasValidBaseType(x, OpTypeFloat)
operand1 = context.get_random_operand(predicate)
if not operand1:
return []
operand2 = context.get_random_operand(predicate, operand1)
self.type = operand1.get_base_type()
self.operand1 = operand1
self.operand2 = operand2
return [self]
4 changes: 3 additions & 1 deletion src/predicates.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from src.enums import StorageClass
from src.types.abstract_types import ArithmeticType
from src.types.concrete_types import OpTypeBool, OpTypeInt, OpTypeVector
from src.types.concrete_types import OpTypeBool, OpTypeFloat, OpTypeInt, OpTypeMatrix, OpTypeVector

IsVectorType = lambda x: isinstance(x.type, OpTypeVector)
IsMatrixType = lambda x: isinstance(x.type, OpTypeMatrix)
IsScalarInteger = lambda x: isinstance(x.type, OpTypeInt)
IsScalarFloat = lambda x: isinstance(x.type, OpTypeFloat)
HasValidBaseType = lambda x, target_type: isinstance(x.get_base_type(), target_type)
HasValidSign = (
lambda x, signed: x.get_base_type().signed == signed if signed is not None else True
Expand Down

0 comments on commit 904d0dd

Please sign in to comment.