Add Q4_1_O quantization format that preserves outliers in weights and does dot in FP32 #825

Closed

Changes from 1 commit (55 commits in total)
2f51451
Initial commit
saharNooby Mar 30, 2023
873cb95
Make ln0 work correctly
saharNooby Mar 30, 2023
56bf4fc
Implement time mixing, fix matrix shape mismatch
saharNooby Mar 30, 2023
93c8dca
Update README.md
saharNooby Mar 30, 2023
fe272dc
Minor changes
saharNooby Mar 31, 2023
01d667f
Implement exp, max, 1_minus_x, sigmoid operators in ggml
saharNooby Mar 31, 2023
02c9946
Update README.md
saharNooby Mar 31, 2023
d00f285
Add reference implementation of RWKV RNN
saharNooby Mar 31, 2023
61c6b1a
Add comparison against reference implementation script, implement sta…
saharNooby Mar 31, 2023
6fe9486
Finally, FP32 inference
saharNooby Apr 1, 2023
bf88e8a
Update README.md
saharNooby Apr 1, 2023
0fcb7c6
Remove reference implementation code and test against pre-created logits
saharNooby Apr 1, 2023
16ec7a5
Add fail-fast version of the test
saharNooby Apr 1, 2023
fe98c94
[FILE FORMAT CHANGED] Use ggml_get_rows to get embedding
saharNooby Apr 1, 2023
f6d45ba
Support FP16 inference
saharNooby Apr 1, 2023
ac03019
Move model to separate C library file
saharNooby Apr 1, 2023
7130a89
[FILE FORMAT CHANGED] Reverse dimensions in ggml file (makes it more …
saharNooby Apr 1, 2023
a1e1d34
Add Python wrapper for C library
saharNooby Apr 1, 2023
b164bf4
Allocate memory as needed for specific configuration of model
saharNooby Apr 1, 2023
972e28d
Implement INT4 conversion and inference
saharNooby Apr 1, 2023
38f9d02
Fix quantization from FP16
saharNooby Apr 1, 2023
935d16f
Move library wrapper to separate file, refactor code
saharNooby Apr 2, 2023
1ecbad3
Remove unused files
saharNooby Apr 2, 2023
ee46ad2
Add quantization test back, run ggml tests on first context init
saharNooby Apr 2, 2023
e0684e8
Add text generation and chat scripts
saharNooby Apr 2, 2023
6b4ebc3
Update README.md
saharNooby Apr 2, 2023
f2b1dad
Add GitHub workflows file
saharNooby Apr 2, 2023
1262ad0
Fix build errors and warnings
saharNooby Apr 2, 2023
d62a050
Remove hardcoded memory requirements table
saharNooby Apr 2, 2023
a64aaa8
initial addition
hypnopump Apr 2, 2023
5b2830e
Increase memory for overhead from 32 MB to 256 MB
saharNooby Apr 3, 2023
3535476
Update README.md: include info about pre-compiled library
saharNooby Apr 3, 2023
6f3fb01
suggestions
hypnopump Apr 3, 2023
0a0cabc
for consistency
hypnopump Apr 3, 2023
bea02c4
Merge branch 'master' into more_instructions_works_linux
hypnopump Apr 3, 2023
fa74b01
more details for macos/linux
hypnopump Apr 3, 2023
4f1df7c
Merge pull request #9 from hypnopump/more_instructions_works_linux
saharNooby Apr 3, 2023
aacc8b6
Minor formatting changes
saharNooby Apr 3, 2023
977efba
we actually build a dylib on macos
pixelkaiser Apr 4, 2023
77e1998
Merge pull request #13 from pixelkaiser/rwkv-macos
saharNooby Apr 4, 2023
b75a805
working on macos. no point in fp32 if all weights distributed in fp16
hypnopump Apr 4, 2023
f5feb74
verify instructions can be followed
hypnopump Apr 4, 2023
c320573
verify instructions can be followed
hypnopump Apr 4, 2023
a9cb9ad
streaming output
hypnopump Apr 4, 2023
d380134
streaming output
hypnopump Apr 4, 2023
dc679bf
Merge pull request #14 from hypnopump/update_macos
saharNooby Apr 4, 2023
d12088e
Minor formatting changes
saharNooby Apr 5, 2023
ad3a4eb
Add missing labels and symbols for new operators
saharNooby Apr 6, 2023
fa9ad13
Free ggml context when model is garbage collected
saharNooby Apr 5, 2023
058b5cd
Show file compression ratio
saharNooby Apr 4, 2023
ec99bc1
Do not quantize head
saharNooby Apr 6, 2023
c40941d
Add Q4_1_O format
saharNooby Apr 7, 2023
18bf02f
Use ggml function for parameter size calculation
saharNooby Apr 7, 2023
e26b408
Add Q4_1_O test
saharNooby Apr 7, 2023
edd57a1
Update README.md
saharNooby Apr 7, 2023
Add reference implementation of RWKV RNN
saharNooby committed Mar 31, 2023
commit d00f28581af8ff8ee80306ee486d7b95be220832
239 changes: 239 additions & 0 deletions rwkv/rwkv_model.py
@@ -0,0 +1,239 @@
# Reference implementation of RWKV in PyTorch.

# Original code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
# Original code by https://github.com/BlinkDL, licensed under Apache License 2.0

# Improvements made to the original code:
# - safetensors loading support
# - LoRA loading support
# - ln0 absorption support
# - general code style improvements
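# ("ln0 absorption" folds blocks.0.ln0 into the embedding matrix once at load time,
# so the first LayerNorm does not need to be applied on every forward pass; see
# absorb_layer_norm_0 in RWKV_RNN.__init__ below.)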

import time
import torch
import types
from typing import Union, Tuple, Dict, Optional
from torch.nn import functional as F

LORA_R: int = 4
LORA_ALPHA: int = 32
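# LoRA checkpoints passed as additional_model_path must use this rank (asserted below);
# their factors are merged into the base weights as
# W += (LORA_ALPHA / LORA_R) * (lora_B @ lora_A) in RWKV_RNN.__init__.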

def load_state_dict(file_path: str, device: str) -> Dict[str, torch.Tensor]:
    print(f'Loading {file_path}')

    if file_path.endswith('.safetensors'):
        from safetensors import safe_open

        w = {}

        with safe_open(file_path, framework='pt', device=device) as state_dict:
            for key in state_dict.keys():
                w[key] = state_dict.get_tensor(key)

        return w
    else:
        return torch.load(file_path, map_location=device)

def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    n_layer = 0

    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1

    assert n_layer > 0

    return n_layer

class RWKV_RNN(torch.jit.ScriptModule):

    def __init__(
        self,
        model_path: str,
        additional_model_path: Optional[str] = None,
        device: str = 'cpu',
        absorb_layer_norm_0: bool = False
    ):
        super().__init__()

        self.representation: torch.Tensor = torch.tensor([0], dtype=torch.float32, device=device)
        self.eval()

        print(f'Loading RWKV model from {model_path}')

        w = load_state_dict(model_path, device)

        if additional_model_path is not None:
            additional_w = load_state_dict(additional_model_path, device)

            for k in additional_w:
                if k != '_training_state':
                    w[k] = additional_w[k]

            print('Merging LoRA into weights')

            start = time.time()

            for k in list(w.keys()):
                module_k = k.replace('.weight', '')

                if module_k + '.lora_A.weight' in w:
                    lora_A = w[module_k + '.lora_A.weight']
                    lora_B = w[module_k + '.lora_B.weight']
                    assert lora_B.shape[1] == lora_A.shape[0] == LORA_R
                    w[module_k + '.weight'] = w[module_k + '.weight'] + lora_B @ lora_A * (LORA_ALPHA / LORA_R)
                    del w[module_k + '.lora_A.weight']
                    del w[module_k + '.lora_B.weight']
                    del lora_A
                    del lora_B

            print('Took %.3f sec' % ((time.time() - start),))

        for k in w.keys():
            if '.time_' in k:
                # (1, 1, n_embed) -> (n_embed)
                w[k] = w[k].squeeze()

            if '.time_decay' in k:
                # The real time decay is like e^{-e^x}
                w[k] = -torch.exp(w[k].float())
            elif w[k].dtype != torch.float32:
                w[k] = w[k].float()

        self.w = types.SimpleNamespace()
        self.w.blocks = {}

        # Example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w

            for p in parts:
                if p.isdigit():
                    p = int(p)

                    if p not in here:
                        here[p] = types.SimpleNamespace()

                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())

                    here = getattr(here, p)

            setattr(here, last, w[k])

        # n_layer/n_embed must be set before the ln0 absorption below,
        # because layer_norm() uses self.n_embed.
        self.n_layer = get_layer_count(w)
        self.n_embed = self.w.emb.weight.shape[1]

        self.absorb_layer_norm_0 = absorb_layer_norm_0

        if absorb_layer_norm_0:
            print('Absorbing first LayerNorm into embedding matrix')

            start = time.time()

            for i in range(len(self.w.emb.weight)):
                self.w.emb.weight[i] = self.layer_norm(self.w.emb.weight[i], self.w.blocks[0].ln0)

            print('Took %.3f sec' % ((time.time() - start),))

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.n_embed,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
        state[5 * i + 0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @torch.jit.script_method
    def time_mixing(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
        state[5 * i + 1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        # Streaming WKV computation, kept in log-space for numerical stability:
        # aa/bb are the exponentially-weighted numerator/denominator accumulators,
        # pp is the running maximum exponent, so exp() never overflows.
        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5 * i + 2] = e1 * aa + e2 * v
        state[5 * i + 3] = e1 * bb + e2
        state[5 * i + 4] = qq
        return ow @ (r * wkv)

    def warm_up(self):
        print('Warming up the model')
        start = time.time()
        self.forward(0, None)
        print('Took %.3f sec' % ((time.time() - start),))

    def forward(self, token: int, state: Union[torch.Tensor, None], save_representation: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            x: torch.Tensor = self.w.emb.weight[token]

            if state is None:
                # 5 state vectors per layer: [0] channel-mixing x, [1] time-mixing x,
                # [2] WKV numerator (aa), [3] WKV denominator (bb), [4] max exponent (pp).
                state = torch.zeros(self.n_layer * 5, self.n_embed, device=x.device)

                for i in range(self.n_layer):
                    # ~Negative infinity
                    state[5 * i + 4] = -1e30

            if not self.absorb_layer_norm_0:
                x = self.layer_norm(x, self.w.blocks[0].ln0)

            for i in range(self.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln1),
                    state,
                    i,
                    att.time_mix_k,
                    att.time_mix_v,
                    att.time_mix_r,
                    att.time_first,
                    att.time_decay,
                    att.key.weight,
                    att.value.weight,
                    att.receptance.weight,
                    att.output.weight
                )

                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln2),
                    state,
                    i,
                    ffn.time_mix_k,
                    ffn.time_mix_r,
                    ffn.key.weight,
                    ffn.value.weight,
                    ffn.receptance.weight
                )

            x = self.layer_norm(x, self.w.ln_out)

            if save_representation:
                self.representation = x.clone()

            x = (self.w.head.weight @ x).float()

            return x, state
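
For context, a minimal usage sketch of this reference implementation follows (not part of the committed file). The checkpoint path and token ids are placeholder assumptions, and greedy argmax decoding stands in for real tokenization and sampling:

# Usage sketch only; the checkpoint path and token ids below are placeholders.
import torch
from rwkv.rwkv_model import RWKV_RNN

model = RWKV_RNN('RWKV-4-Pile-169M.pth', device='cpu')  # placeholder checkpoint path

state = None
token = 0  # a real run would obtain token ids from the RWKV tokenizer

for _ in range(4):
    # forward() returns next-token logits and the updated recurrent state.
    logits, state = model.forward(token, state)
    token = int(torch.argmax(logits))  # greedy pick instead of proper sampling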