Skip to content

Commit

Permalink
Use NumericType classes, additional tests, transition to simplejson. (
Browse files Browse the repository at this point in the history
#468)

* Initial commit.

* Formatting.

* Remove double defn.

* use simplejson.

* Add simplejson requirement.

* Add workflow update.

* Fud

* Cleanup, use errors.

* Use value.

* Add more tests.

* Add type.

* Fix parsing.

* Fud

* Update path.

* Fud

* Fud

* Trailing whitespace.

* Fud

* Whitespace, address Rachit's comments.

* Remove t
  • Loading branch information
cgyurgyik authored Mar 29, 2021
1 parent 233c7d6 commit 03ad6d4
Show file tree
Hide file tree
Showing 9 changed files with 456 additions and 329 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:
uses: actions/setup-python@v2

- name: Install common Python dependencies
run: pip3 install numpy flit prettytable wheel hypothesis pytest
run: pip3 install numpy flit prettytable wheel hypothesis pytest simplejson

- name: Cache Apache TVM
id: tvm-cache
Expand Down Expand Up @@ -180,7 +180,7 @@ jobs:
# Run the remaining tests
runt -x dahlia -d -o fail
- name: Run Python Tests
run: pytest fud/fud/stages/verilator/tests/fp_parse.py
run: pytest fud/fud/stages/verilator/tests/numeric_types.py

format:
name: Check Formatting
Expand Down
7 changes: 2 additions & 5 deletions fud/fud/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,8 @@ class InvalidNumericType(FudError):
An error raised when an invalid numeric type is provided.
"""

def __init__(self, type):
msg = f"""Numeric type: {type} is not supported.
Give a valid numeric type input. We currently support:
(1) bitnum
(2) fixed point"""
def __init__(self, msg):
msg = f"""Invalid Numeric Type: {msg}"""
super().__init__(msg)


Expand Down
192 changes: 0 additions & 192 deletions fud/fud/stages/verilator/fixed_point.py

This file was deleted.

97 changes: 27 additions & 70 deletions fud/fud/stages/verilator/json_to_dat.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import json
import simplejson as sjson
import numpy as np
from .fixed_point import fp_to_decimal, decimal_to_fp
from .numeric_types import FixedPoint, Bitnum
from pathlib import Path
from fud.errors import InvalidNumericType, Malformed


# Converts `val` into a bitstring that is `bw` characters wide
def bitstring(val, bw):
# first truncate val by `bw` bits
val &= 2 ** bw - 1
return "{:x}".format(val)


def parse_dat_bitnum(path, bw, is_signed):
"""Parses bitnum numbers of bit width `bw`
from the array at the given `path`.
def parse_dat(path, args):
"""Parses a number with the given numeric type
arguments from the array at the given `path`.
"""

if not path.exists():
raise Malformed(
"Data directory",
Expand All @@ -29,40 +21,15 @@ def parse_dat_bitnum(path, bw, is_signed):
),
)

def to_decimal(hex_v: str) -> int:
# Takes in a value in string
# hexadecimal form, and
# returns the corresponding
# integer value.
v = int(hex_v.strip(), 16)
if is_signed and v > (2 ** (bw - 1)):
return -1 * ((2 ** bw) - v)

return v

with path.open("r") as f:
return np.array(list(map(to_decimal, f.readlines())))


def parse_dat_fp(path, width, int_width, is_signed):
"""Parses fixed point numbers in the array
at `path` with the following form:
Total width: `width`
Integer width: `int_width`
Fractional width: `width` - `int_width`
"""

def hex_to_decimal(v):
# Given a fixed point number in hexadecimal form,
# returns the string form of the decimal value.
decimal_v = fp_to_decimal(
np.binary_repr(int(v.strip(), 16), width), width, int_width, is_signed
)
# Stringified since Decimal is not JSON serializable.
return str(decimal_v)
def parse(hex_value: str):
hex_value = f"0x{hex_value}"
if "int_width" in args:
return FixedPoint(hex_value, **args).str_value()
else:
return int(Bitnum(hex_value, **args).str_value())

with path.open("r") as f:
return np.array(list(map(hex_to_decimal, f.readlines())))
return np.array([parse(hex_value) for hex_value in f.readlines()])


def parse_fp_widths(format):
Expand Down Expand Up @@ -111,20 +78,15 @@ def convert2dat(output_dir, data, extension):
for k, item in data.items():
path = output_dir / f"{k}.{extension}"
path.touch()
arr = np.array(item["data"])
arr = np.array(item["data"], str)
format = item["format"]

# Every numeric format shares these two fields.
numeric_type = format["numeric_type"]
is_signed = format["is_signed"]
shape[k] = {
"shape": list(arr.shape),
"numeric_type": numeric_type,
"is_signed": is_signed,
}
shape[k] = {"is_signed": is_signed}

if numeric_type not in {"bitnum", "fixed_point"}:
raise InvalidNumericType(numeric_type)
raise InvalidNumericType('Fud only supports "fixed_point" and "bitnum".')

is_fp = numeric_type == "fixed_point"
if is_fp:
Expand All @@ -135,18 +97,21 @@ def convert2dat(output_dir, data, extension):
width = format["width"]
shape[k]["width"] = width

convert = (
lambda x: decimal_to_fp(x, width, int_width, is_signed) if is_fp else x
)
def convert(x):
NumericType = FixedPoint if is_fp else Bitnum
return NumericType(x, **shape[k]).hex_string(with_prefix=False)

with path.open("w") as f:
for v in arr.flatten():
f.write(bitstring(convert(v), width) + "\n")
f.write(convert(v) + "\n")

shape[k]["shape"] = list(arr.shape)
shape[k]["numeric_type"] = numeric_type

# Commit shape.json file.
shape_json_file = output_dir / "shape.json"
with shape_json_file.open("w") as f:
json.dump(shape, f, indent=2)
sjson.dump(shape, f, indent=2, use_decimal=True)


def convert2json(input_dir, extension):
Expand All @@ -161,21 +126,13 @@ def convert2json(input_dir, extension):
return {}

data = {}
shape_json = json.load(shape_json_path.open("r"))
shape_json = sjson.load(shape_json_path.open("r"), use_decimal=True)

for (mem, form) in shape_json.items():
path = input_dir / f"{mem}.{extension}"
numeric_type = form["numeric_type"]
is_signed = form["is_signed"]
width = form["width"]

if numeric_type == "bitnum":
arr = parse_dat_bitnum(path, width, is_signed)
elif numeric_type == "fixed_point":
arr = parse_dat_fp(path, width, form["int_width"], is_signed)
else:
raise InvalidNumericType(numeric_type)

args = form.copy()
args.pop("shape"), args.pop("numeric_type")
arr = parse_dat(path, args)
if form["shape"] == [0]:
raise Malformed(
"Data format shape",
Expand Down
Loading

0 comments on commit 03ad6d4

Please sign in to comment.