Skip to content

Commit

Permalink
Test for vector shifting
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk committed Dec 3, 2024
1 parent 156a569 commit fb08d2a
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 7 deletions.
90 changes: 86 additions & 4 deletions test/test_shifter.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import random
import pytest
from collections.abc import Callable, Iterable
from typing import Any
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from amaranth import *
from amaranth_types.types import ValueLike
from amaranth.lib import data
from amaranth_types.types import ShapeCastable, ShapeLike, ValueLike
from transactron.utils import assign
from transactron.utils.amaranth_ext.shifter import *
from transactron.testing import TestCaseWithSimulator, TestbenchContext


class ShifterCircuit(Elaboratable):
def __init__(
self, shift_fun: Callable[[ValueLike, ValueLike], Value], width, shift_kwargs: Iterable[tuple[str, Any]] = ()
self,
shift_fun: Callable[[ValueLike, ValueLike], Value],
width: int,
shift_kwargs: Iterable[tuple[str, Any]] = (),
):
self.input = Signal(width)
self.output = Signal(width)
Expand Down Expand Up @@ -62,3 +67,80 @@ async def test_process(sim: TestbenchContext):

with self.run_simulation(dut, add_transaction_module=False) as sim:
sim.add_testbench(test_process)


class VecShifterCircuit(Elaboratable):
def __init__(
self,
shift_fun: Callable[[Sequence, ValueLike], Sequence],
shape: ShapeLike,
width: int,
shift_kwargs: Iterable[tuple[str, Any]] = (),
):
self.input = Signal(data.ArrayLayout(shape, width))
self.output = Signal(data.ArrayLayout(shape, width))
self.offset = Signal(range(width + 1))
self.shift_fun = shift_fun
self.kwargs = dict(shift_kwargs)

def elaborate(self, platform):
m = Module()

m.d.comb += assign(self.output, self.shift_fun(cast(Sequence, self.input), self.offset, **self.kwargs))

return m


class TestVecShifter(TestCaseWithSimulator):
@pytest.mark.parametrize(
"shape",
[
4,
data.ArrayLayout(2, 2),
],
)
@pytest.mark.parametrize(
"shift_fun, shift_kwargs, test_fun",
[
(shift_vec_left, lambda mkc: [], lambda val, offset, mkc: [mkc(0)] * offset + val[: len(val) - offset]),
(shift_vec_right, lambda mkc: [], lambda val, offset, mkc: val[offset:] + [mkc(0)] * offset),
(
shift_vec_left,
lambda mkc: [("placeholder", mkc(0))],
lambda val, offset, mkc: [mkc(0)] * offset + val[: len(val) - offset],
),
(
shift_vec_right,
lambda mkc: [("placeholder", mkc(0))],
lambda val, offset, mkc: val[offset:] + [mkc(0)] * offset,
),
(
rotate_vec_left,
lambda mkc: [],
lambda val, offset, mkc: val[len(val) - offset :] + val[: len(val) - offset],
),
(rotate_vec_right, lambda mkc: [], lambda val, offset, mkc: val[offset:] + val[:offset]),
],
)
def test_vec_shifter(self, shape, shift_fun, shift_kwargs, test_fun):
def mk_const(x):
if isinstance(shape, ShapeCastable):
return shape.from_bits(0)
else:
return C(x, shape)

width = 8
tests = 50
dut = VecShifterCircuit(shift_fun, shape, width, shift_kwargs(mk_const))

async def test_process(sim: TestbenchContext):
for _ in range(tests):
val = [mk_const(random.randrange(2 ** Shape.cast(shape).width)) for _ in range(width)]
offset = random.randrange(width + 1)
sim.set(dut.input, val)
sim.set(dut.offset, offset)
_, result = await sim.delay(1e-9).sample(dut.output)
assert result == test_fun(val, offset, mk_const)

with self.run_simulation(dut, add_transaction_module=False) as sim:
sim.add_testbench(test_process)
14 changes: 11 additions & 3 deletions transactron/utils/amaranth_ext/shifter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from amaranth import *
from amaranth.hdl import ValueCastable
from collections.abc import Sequence
from typing import Optional, TypeVar, overload
from typing import Optional, TypeVar, cast, overload
from amaranth_types.types import ValueLike
from .functions import shape_of

Expand Down Expand Up @@ -134,7 +134,11 @@ def shift_vec_right(
placeholder: Optional[ValueLike | ValueCastable] = None,
) -> Sequence[Value | ValueCastable]:
if placeholder is None:
placeholder = C(0, shape_of(data[0]))
shape = shape_of(data[0])
if isinstance(shape, Shape):
placeholder = C(0, shape)
else:
placeholder = cast(ValueLike, shape.from_bits(0))
return generic_shift_vec_right(data, [placeholder] * len(data), offset)


Expand All @@ -156,7 +160,11 @@ def shift_vec_left(
placeholder: Optional[ValueLike | ValueCastable] = None,
) -> Sequence[Value | ValueCastable]:
if placeholder is None:
placeholder = C(0, shape_of(data[0]))
shape = shape_of(data[0])
if isinstance(shape, Shape):
placeholder = C(0, shape)
else:
placeholder = cast(ValueLike, shape.from_bits(0))
return generic_shift_vec_left(data, [placeholder] * len(data), offset)


Expand Down

0 comments on commit fb08d2a

Please sign in to comment.