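"""Text generation sample for Mamba models implemented in JAX with Equinox.

Loads either a pretrained checkpoint from the `state-spaces` Hugging Face
namespace or a local JAX checkpoint, then samples text token by token, or in
a single jit-compiled scan loop when --scan is passed.
"""
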
import argparse
import json
import string
from types import SimpleNamespace

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from transformers import AutoTokenizer

from mamba_jax.kernels import KernelTypeMapping
from mamba_jax.modelling.equinox import MambaLLM
from mamba_jax.modelling.equinox.loader import load_pretrained


# simple character-level tokenizer for the text8 generation task:
# 'a'-'z' map to ids 0-25 and the space character to id 26
class Text8Tokenizer:
    def __init__(self):
        lower, upper = string.ascii_lowercase[0], string.ascii_lowercase[-1]
        self.lower, self.upper = ord(lower), ord(upper)
        self.space = self.upper + 1
        # text8 has no end-of-sequence token, so use a sentinel id that is never sampled
        self.eos_token_id = -1

    def __call__(self, prompt, *args, **kwargs):
        prompt = prompt.lower()
        assert all(c in string.ascii_lowercase + " " for c in prompt), "text8 only supports lowercase letters and spaces"
        byte_vals = np.array([ord(c) for c in prompt], dtype=int)
        byte_vals[byte_vals == ord(" ")] = self.space
        input_ids = byte_vals - self.lower
        # mimic the HuggingFace tokenizer interface: .input_ids is a batch (list) of id arrays
        return SimpleNamespace(input_ids=[input_ids])

    def decode(self, input_ids, *args, **kwargs):
        # map token ids back to byte values; accepts python lists as well as numpy/jax arrays
        byte_vals = np.asarray(input_ids, dtype=int) + self.lower
        return "".join([chr(b) if b != self.space else " " for b in byte_vals])


def load_local_model(args):
    with open(args.config, mode="r") as f:
        config = json.load(f)
    config = SimpleNamespace(**config)

    # TODO: move this under a 'model_config' sub-dictionary so we can just do **config.model_config
    model_kwargs = MambaLLM.args_namespace_to_model_kwargs(config)
    model = MambaLLM(**model_kwargs, key=jax.random.PRNGKey(0))
    model = eqx.tree_deserialise_leaves(args.model, like=model)

    # TODO: for now, the training script only trains on text8, so we just use the
    # tokeniser for that, which can be a simple function
    if args.text8_tokenizer:
        tokenizer = Text8Tokenizer()
    else:
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

    return model, tokenizer


def main(args):
    # TODO: make this more robust, as future models may not be under the state-spaces namespace
    if args.model.startswith("state-spaces"):
        model, tokenizer = load_pretrained(
            args.model,
            dtype=jnp.bfloat16 if args.bf16 else jnp.float32,
            kernel_mode=KernelTypeMapping[args.kernel_mode],
        )
    else:  # otherwise assume a local JAX checkpoint
        model, tokenizer = load_local_model(args)

    prompt = args.prompt
    gen_len = args.gen_len
    temperature = args.temperature

    # jit-compile the single-token decode step, and optionally the full scan-based generate loop
    generate_step = eqx.filter_jit(model.generate_step)
    generate_fn = None
    if args.scan:
        generate_fn = eqx.filter_jit(model.generate)

    for p in range(args.seed_iters):
        # fresh recurrent state cache for this generation run
        cache = model.init_cache()
        input_ids = tokenizer(prompt, return_tensors="np").input_ids[0]
        key = jax.random.PRNGKey(args.seed + p)
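        # scan path: generate the whole sequence inside a single jit-compiled call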
        if generate_fn is not None:
            output_ids = generate_fn(input_ids, gen_len, temperature=temperature, key=key)
            print(tokenizer.decode(output_ids, skip_special_tokens=True))
            print("\n\n")
            continue

        print()
        print(tokenizer.decode(input_ids, skip_special_tokens=True), flush=True, end="")
# "prefill"
for input_id in input_ids[:-1]:
_, cache = generate_step(input_id, cache=cache)
input_id = input_ids[-1]
output = input_ids.tolist()
        for _ in range(gen_len):
            logits, cache = generate_step(input_id, cache=cache)
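            # temperature scaling: values below 1 sharpen the distribution, above 1 flatten it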
            logits = logits / temperature
            key, subkey = jax.random.split(key)
            input_id = jax.random.categorical(subkey, logits)
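            # stop early if the model samples the end-of-sequence token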
            if input_id == tokenizer.eos_token_id:
                print("\n\n\n")
                break

            output.append(input_id)
            print(tokenizer.decode([output[-1]]), flush=True, end="")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--prompt", type=str, default="Aloha, World! ", help="Starting prompt for generation.")
    parser.add_argument(
        "--model",
        type=str,
        default="state-spaces/mamba-2.8b",
        help="Local JAX model checkpoint path, or a PyTorch model repo id on the Hugging Face Hub.",
    )
    parser.add_argument(
        "--config", type=str, default=None, help="Path to model config file if using a local JAX model checkpoint."
    )
    parser.add_argument(
        "--bf16",
        action="store_true",
        help="Use bfloat16 for inference. If using `--config`, the value there overrides this flag.",
    )
    # main() reads args.kernel_mode when loading pretrained models, so the flag must be defined here;
    # valid values are the keys of KernelTypeMapping (the default below is an assumed placeholder)
    parser.add_argument("--kernel_mode", type=str, default="xla", help="Kernel implementation for pretrained models.")
    parser.add_argument("--gen_len", type=int, default=1024, help="Length of generated sequence.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature.")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for PRNG initialisation.")
    parser.add_argument("--seed_iters", type=int, default=1, help="Number of seeds to generate, starting from --seed.")
    parser.add_argument("--scan", action="store_true", help="Use jax.lax.scan version of generate loop.")
    parser.add_argument(
        "--text8_tokenizer", action="store_true", help="Use the text8 tokenizer. Only use for models trained on text8."
    )
    args = parser.parse_args()

    main(args)
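
# Example invocations (paths are illustrative):
#   python sample.py --model state-spaces/mamba-2.8b --prompt "Aloha, World! " --bf16
#   python sample.py --model /path/to/model.eqx --config /path/to/config.json --text8_tokenizer --prompt "hello world"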