Skip to content

Commit

Permalink
xclrun pass in floating point
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 committed Oct 28, 2024
1 parent 61e886b commit b3ee79a
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions fud/fud/xclrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def convert_to_fp(value: float):

convert_to_fp(buf)
return list(buf)
elif fmt["numeric_type"] == "bitnum":
return list([int(e) for e in buf])
elif fmt["numeric_type"] in {"bitnum", "floating_point"}:
return [int(value) if isinstance(value, np.integer) else float(value) for value in buf]

else:
raise InvalidNumericType('Fud only supports "fixed_point" and "bitnum".')
raise InvalidNumericType('Fud only supports "fixed_point", "bitnum", and "floating_point".')


def run(xclbin: Path, data: Mapping[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -110,7 +110,10 @@ def run(xclbin: Path, data: Mapping[str, Any]) -> Dict[str, Any]:
def _dtype(fmt) -> np.dtype:
# See https://numpy.org/doc/stable/reference/arrays.dtypes.html for typing
# details
type_string = "i" if fmt["is_signed"] else "u"
if (fmt["numeric_type"] == "floating_point"):
type_string = "f"
else:
type_string = "i" if fmt["is_signed"] else "u"
byte_size = int(fmt["width"] / 8)
type_string = type_string + str(byte_size)
return np.dtype(type_string)
Expand Down

0 comments on commit b3ee79a

Please sign in to comment.