Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed point std_exp #404

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 0 additions & 41 deletions fixedpoint/fixedpoint.py

This file was deleted.

19 changes: 10 additions & 9 deletions frontends/relay/dahlia_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,27 +383,25 @@ def softmax(fd: DahliaFuncDef) -> str:
"""tvm.apache.org/docs/api/python/relay/nn.html#tvm.relay.nn.softmax"""
data, res = fd.args[0], fd.dest
axis = fd.attributes.get_int("axis")
assert axis == -1 or axis == 1, f'nn.softmax with axis = {axis} is not supported.'
assert axis == -1 or axis == 1, f'softmax with axis = {axis} is not supported.'

data_type = fd.data_type
assert 'fix' in data_type, f'softmax not supported for {data_type}.'
size0, size1, index_size0, index_size1 = data.comp.args[1:5]

# The value of `e` if Q = 32.16, otherwise `3`.
e = '13044242090' if 'fix' in data_type else '3'

return emit_dahlia_definition(
fd,
f"""let e: {data_type} = {e};
f"""
for (let i: ubit<{index_size0}> = 0..{size0}) {{
let {data.id.name}_expsum: {data_type} =
{'0.0' if 'fix' in data_type else '0'};

for (let j: ubit<{index_size1}> = 0..{size1}) {{
let tmp1 = std_exp(e, {data.id.name}[i][j]);
let tmp1 = std_fp_exp({data.id.name}[i][j]);
{data.id.name}_expsum += tmp1;
}}
for (let k: ubit<{index_size1}> = 0..{size1}) {{
let tmp2 = std_exp(e, {data.id.name}[i][k]);
let tmp2 = std_fp_exp({data.id.name}[i][k]);
{res.id.name}[i][k] := tmp2;
{res.id.name}[i][k] :=
{res.id.name}[i][k] / {data.id.name}_expsum;
Expand Down Expand Up @@ -453,14 +451,17 @@ def emit_components(func_defs: List[DahliaFuncDef]) -> str:
# If the function is a binary operation, use broadcasting.
# Otherwise, use the associated Relay function.
apply = broadcast if id in BinaryOps else RelayCallNodes[id]
dahlia_definitions.append(apply(func_def))
dahlia_definitions.append(
apply(func_def)
)

type = func_defs[0].data_type
imports = [
f"""import futil("primitives/bitnum/math.futil")
{{
def std_exp(base: {type}, exp: {type}): {type};
def std_sqrt(in: {type}): {type};
def std_pow(base: {type}, exp: {type}): {type};
def std_fp_exp(exponent: {type}): {type};
}}"""
]

Expand Down
8 changes: 6 additions & 2 deletions frontends/relay/relay_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,18 @@ def get_dahlia_data_type(relay_type) -> str:
Relay | Dahlia
--------|-------------------------------
int | (`bit`, width)
float | (`fix`, (width, width // 2))
float | (`fix`, (32, 28))*

* Currently hard-coded to Q28.4,
to support fixed point `exp`.
"""
width = get_bitwidth(relay_type)

if 'int' in relay_type.dtype:
return f'bit<{width}>'
if 'float' in relay_type.dtype:
return f'fix<{width}, {width // 2}>'
assert width == 32, f'Fixed point of width: {width} not supported.'
return f'fix<{width}, 28>'
assert 0, f'{relay_type} is not supported.'


Expand Down
103 changes: 103 additions & 0 deletions primitives/bitnum/math.futil
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,106 @@ component std_pow(base: 32, exp: 32) -> (out: 32) {
}
}
}

// Computes the unsigned value e^exponent.
// Uses fixed point format Q28.4, and an
// approximation table to estimate the
// fractional power of e.
// TODO(cgyurgyik): Eventually, we want to support
// component gen for fixed point `exp` of any width.
component std_fp_exp(exponent: 32) -> (out: 32) {
cells {
// mathematical constant e in Q28.4
e = std_const(32, 43);
// Approximation of e using chebyshev
// polynomials within bounds [0, 1].
e_table = std_mem_d1(32, 16, 32);

pow = std_reg(32);
mul = fixed_p_std_mult(32, 28, 4);

int_bits = std_reg(32);
frac_bits = std_reg(32);
and0 = std_and(32);
rsh0 = std_rsh(32);
and1 = std_and(32);

integer_count = std_reg(32);
lt = std_lt(32);
incr = std_add(32);
fractional_value = std_reg(32);
}
wires {
group init0 {
// Mask integer bits, and shift right.
and0.left = exponent;
and0.right = 32'd4294967280;
rsh0.left = and0.out;
rsh0.right = 32'd4;
int_bits.in = rsh0.out;

// Mask fractional bits.
and1.left = exponent;
and1.right = 32'd15;
frac_bits.in = and1.out;

int_bits.write_en = 1'd1;
frac_bits.write_en = 1'd1;
init0[done] = int_bits.done & frac_bits.done ? 1'd1;
}
group init1 {
pow.in = 32'd16; // 1.0 in Q28.4
pow.write_en = 1'd1;
integer_count.in = 32'd0;
integer_count.write_en = 1'd1;
init1[done] = pow.done & integer_count.done ? 1'd1;
}
group do_mul {
mul.left = e.out;
mul.right = pow.out;
pow.in = mul.out;
pow.write_en = 1'd1;
do_mul[done] = pow.done;
}
group incr_count {
incr.left = 32'd1;
incr.right = integer_count.out;
integer_count.in = incr.out;
integer_count.write_en = 1'd1;
incr_count[done] = integer_count.done;
}
group cond {
lt.right = int_bits.out;
lt.left = integer_count.out;
cond[done] = 1'd1;
}
group get_fractional_value {
e_table.addr0 = frac_bits.out;
fractional_value.in = e_table.read_data;
fractional_value.write_en = 1'd1;
get_fractional_value[done] = fractional_value.done;
}
group mult_int_frac {
mul.left = pow.out;
mul.right = fractional_value.out;
pow.in = mul.out;
pow.write_en = 1'd1;
mult_int_frac[done] = pow.done;
}

out = pow.out;
}
control {
seq {
par { init0; init1; }
// Compute e^i, where i is the integer value.
while lt.out with cond {
par { do_mul; incr_count; }
}
// Lookup e^f, where f is the fractional value.
get_fractional_value;
// Compute e^x = e^i * e^f.
mult_int_frac;
}
}
}
1 change: 0 additions & 1 deletion primitives/bitnum/math.sv
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,3 @@ module std_sqrt (
`endif

endmodule

155 changes: 155 additions & 0 deletions primitives/fixed-point-gen/fp-exp-table-gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import numpy as np
from itertools import product


def decimal_to_fixed_p(num, width, int_bit, frac_bit):
"""Given the number, width, integer bit and fractional bit,
returns the fixed point representation. If the fraction
cannot be represented exactly in fixed point, it will be
rounded to the nearest whole number.

Example:
decimal_to_fixed_p(11.125,8,5,3)
returns 01011001 = 2^3+2^1+2^0+2^(-3)
Preconditions:
1. There is no overflow.
2. Integer part of the number should be
representable with int_bit number of bits.
"""
# separate into integer and fractional parts
intg, frac = str(num).split(".")
int_b = np.binary_repr(int(intg), width=int_bit)
frac = "0." + frac

# multiply fractional part with 2**frac_bit to turn into integer
frac = float(frac) * float(2 ** frac_bit)
_, f = str(frac).split(".")

# Rounds up if the tenths place is >= 5.
tenths_place = int(f[0])
frac = int(frac) if tenths_place < 5 else int(frac + 1)

frac_b = np.binary_repr(frac, width=frac_bit)
r = int_b + frac_b
return r


def fixed_p_to_decimal(fp, width, int_bit, frac_bit):
"""Given fixedpoint representation, width,
integer bit and fractinal bit, returns the number.
example: fixed_p_to_decimal ('01011001',8,5,3) returns 11.125
"""
int_b = fp[:int_bit]
frac_b = fp[int_bit:width]
int_num = int(int_b, 2)
frac = int(frac_b, 2)
frac_num = float(frac / 2 ** (frac_bit))
num = float(int_num + frac_num)
return num


def binary_to_base10(bit_list):
"""Takes a binary number in list form
e.g. [1, 0, 1, 0], and returns
the corresponding base 10 number.
"""
out = 0
for b in bit_list:
out = (out << 1) | b
return out


def compute_exp_frac_table(frac_bit):
"""Computes a table of size 2^frac_bit
for every value of e^x that can be
represented by fixed point in the range [0, 1].
"""
# Chebyshev approximation coefficients for e^x in [0, 1].
# Credits to J. Sach's blogpost here:
# https://www.embeddedrelated.com/showarticle/152.php
coeffs = [
1.7534,
0.85039,
0.10521,
0.0087221,
0.00054344,
0.000027075
]

def chebyshev_polynomial_approx(x):
"""Computes the Chebyshev polynomials
based on the recurrence relation
described here:
en.wikipedia.org/wiki/Chebyshev_polynomials#Definition
"""
# Change from [0, 1] to [-1, 1] for
# better approximation with chebyshev.
u = (2 * x - 1)

Ti = 1
Tn = None
T = u
num_coeffs = len(coeffs)
c = coeffs[0]
for i in range(1, num_coeffs):
c = c + T * coeffs[i]
Tn = 2 * u * T - Ti
Ti = T
T = Tn

return c

# Gets the permutations of 2^f_bit,
# in increasing order.
binary_permutations = map(
list,
product(
[0, 1],
repeat=frac_bit
)
)

e_table = [0] * (2 ** frac_bit)
for p in binary_permutations:
i = binary_to_base10(p)
fraction = float(
i / 2 ** (frac_bit)
)
e_table[i] = chebyshev_polynomial_approx(fraction)

return e_table


def exp(x, width, int_bit, frac_bit, print_results=False):
"""
Computes an approximation of e^x.
This is done by splitting the fixed point number
x into its integral bits `i` and fractional bits `f`,
and computing e^(i + f) = e^i * e^f.

For the fractional portion, a chebyshev
approximation is used.
"""
fp_x = decimal_to_fixed_p(x, width, int_bit, frac_bit)

int_b = fp_x[:int_bit]
int_bin = int(int_b, 2)
frac_b = fp_x[int_bit:width]
frac_bin = int(frac_b, 2)

# Split e^x into e^i * e^f.
e_i = 2.71828 ** int_bin

e_table = compute_exp_frac_table(frac_bit)
e_f = e_table[frac_bin]

# Compute e^i * e^f.
actual = e_i * e_f

if print_results:
accepted = 2.71828 ** x
print(f'e^{x}: {accepted}')
print(f'actual: {actual}')
print(f'relative difference: {(actual - accepted) / actual * 100}%')

return actual
1 change: 0 additions & 1 deletion primitives/fixed/signed.sv
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,3 @@ module fixed_p_std_sdiv #(
assign result = $signed(left / right);
assign out = result[WIDTH+FRACT_WIDTH-1:FRACT_WIDTH];
endmodule

3 changes: 1 addition & 2 deletions primitives/fixed/unsigned.futil
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,4 @@ extern "unsigned.sv" {
WIDTH, INT_WIDTH, FRACT_WIDTH
](left: WIDTH, right: WIDTH) -> (out: WIDTH);


}
}
Loading