Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use NumericType classes, additional tests, transition to simplejson. #468

Merged
merged 21 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:
uses: actions/setup-python@v2

- name: Install common Python dependencies
run: pip3 install numpy flit prettytable wheel hypothesis pytest
run: pip3 install numpy flit prettytable wheel hypothesis pytest simplejson

- name: Cache Apache TVM
id: tvm-cache
Expand Down Expand Up @@ -180,7 +180,7 @@ jobs:
# Run the remaining tests
runt -x dahlia -d -o fail
- name: Run Python Tests
run: pytest fud/fud/stages/verilator/tests/fp_parse.py
run: pytest fud/fud/stages/verilator/tests/numeric_types.py

format:
name: Check Formatting
Expand Down
7 changes: 2 additions & 5 deletions fud/fud/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,8 @@ class InvalidNumericType(FudError):
An error raised when an invalid numeric type is provided.
"""

def __init__(self, type):
msg = f"""Numeric type: {type} is not supported.
Give a valid numeric type input. We currently support:
(1) bitnum
(2) fixed point"""
def __init__(self, msg):
msg = f"""Invalid Numeric Type: {msg}"""
super().__init__(msg)


Expand Down
192 changes: 0 additions & 192 deletions fud/fud/stages/verilator/fixed_point.py

This file was deleted.

97 changes: 27 additions & 70 deletions fud/fud/stages/verilator/json_to_dat.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import json
import simplejson as sjson
import numpy as np
from .fixed_point import fp_to_decimal, decimal_to_fp
from .numeric_types import FixedPoint, Bitnum
from pathlib import Path
from fud.errors import InvalidNumericType, Malformed


# Converts `val` into a bitstring that is `bw` characters wide
def bitstring(val, bw):
# first truncate val by `bw` bits
val &= 2 ** bw - 1
return "{:x}".format(val)


def parse_dat_bitnum(path, bw, is_signed):
"""Parses bitnum numbers of bit width `bw`
from the array at the given `path`.
def parse_dat(path, args):
"""Parses a number with the given numeric type
arguments from the array at the given `path`.
"""

if not path.exists():
raise Malformed(
"Data directory",
Expand All @@ -29,40 +21,15 @@ def parse_dat_bitnum(path, bw, is_signed):
),
)

def to_decimal(hex_v: str) -> int:
# Takes in a value in string
# hexadecimal form, and
# returns the corresponding
# integer value.
v = int(hex_v.strip(), 16)
if is_signed and v > (2 ** (bw - 1)):
return -1 * ((2 ** bw) - v)

return v

with path.open("r") as f:
return np.array(list(map(to_decimal, f.readlines())))


def parse_dat_fp(path, width, int_width, is_signed):
"""Parses fixed point numbers in the array
at `path` with the following form:
Total width: `width`
Integer width: `int_width`
Fractional width: `width` - `int_width`
"""

def hex_to_decimal(v):
# Given a fixed point number in hexadecimal form,
# returns the string form of the decimal value.
decimal_v = fp_to_decimal(
np.binary_repr(int(v.strip(), 16), width), width, int_width, is_signed
)
# Stringified since Decimal is not JSON serializable.
return str(decimal_v)
def parse(hex_value: str):
hex_value = f"0x{hex_value}"
if "int_width" in args:
return FixedPoint(hex_value, **args).str_value()
else:
return int(Bitnum(hex_value, **args).str_value())

with path.open("r") as f:
return np.array(list(map(hex_to_decimal, f.readlines())))
return np.array([parse(hex_value) for hex_value in f.readlines()])


def parse_fp_widths(format):
Expand Down Expand Up @@ -111,20 +78,15 @@ def convert2dat(output_dir, data, extension):
for k, item in data.items():
path = output_dir / f"{k}.{extension}"
path.touch()
arr = np.array(item["data"])
arr = np.array(item["data"], str)
format = item["format"]

# Every numeric format shares these two fields.
numeric_type = format["numeric_type"]
is_signed = format["is_signed"]
shape[k] = {
"shape": list(arr.shape),
"numeric_type": numeric_type,
"is_signed": is_signed,
}
shape[k] = {"is_signed": is_signed}

if numeric_type not in {"bitnum", "fixed_point"}:
raise InvalidNumericType(numeric_type)
raise InvalidNumericType('Fud only supports "fixed_point" and "bitnum".')

is_fp = numeric_type == "fixed_point"
if is_fp:
Expand All @@ -135,18 +97,21 @@ def convert2dat(output_dir, data, extension):
width = format["width"]
shape[k]["width"] = width

convert = (
lambda x: decimal_to_fp(x, width, int_width, is_signed) if is_fp else x
)
def convert(x):
NumericType = FixedPoint if is_fp else Bitnum
return NumericType(x, **shape[k]).hex_string(with_prefix=False)

with path.open("w") as f:
for v in arr.flatten():
f.write(bitstring(convert(v), width) + "\n")
f.write(convert(v) + "\n")

shape[k]["shape"] = list(arr.shape)
shape[k]["numeric_type"] = numeric_type

# Commit shape.json file.
shape_json_file = output_dir / "shape.json"
with shape_json_file.open("w") as f:
json.dump(shape, f, indent=2)
sjson.dump(shape, f, indent=2, use_decimal=True)


def convert2json(input_dir, extension):
Expand All @@ -161,21 +126,13 @@ def convert2json(input_dir, extension):
return {}

data = {}
shape_json = json.load(shape_json_path.open("r"))
shape_json = sjson.load(shape_json_path.open("r"), use_decimal=True)

for (mem, form) in shape_json.items():
path = input_dir / f"{mem}.{extension}"
numeric_type = form["numeric_type"]
is_signed = form["is_signed"]
width = form["width"]

if numeric_type == "bitnum":
arr = parse_dat_bitnum(path, width, is_signed)
elif numeric_type == "fixed_point":
arr = parse_dat_fp(path, width, form["int_width"], is_signed)
else:
raise InvalidNumericType(numeric_type)

args = form.copy()
args.pop("shape"), args.pop("numeric_type")
arr = parse_dat(path, args)
if form["shape"] == [0]:
raise Malformed(
"Data format shape",
Expand Down
Loading