This is a minimal implementation of GPT-2 using JAX. I was reading the "Attention Is All You Need" paper and wanted to implement it myself. It's probably very slow and memory-inefficient, but it's still fun and under 100 lines of code.
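The core of a decoder-only model like GPT-2 is causal (masked) multi-head self-attention: each position may attend only to itself and earlier positions. Below is a minimal sketch in JAX, assuming a fused query/key/value projection and a single unbatched sequence. The function and parameter names (`causal_self_attention`, `w_qkv`, `w_out`) are hypothetical and not taken from this repository.

```python
import jax
import jax.numpy as jnp

def causal_self_attention(x, w_qkv, w_out, n_heads):
    """Multi-head causal self-attention for one sequence (hypothetical sketch).

    x:     (seq_len, d_model) input activations
    w_qkv: (d_model, 3 * d_model) fused query/key/value projection
    w_out: (d_model, d_model) output projection
    """
    seq_len, d_model = x.shape
    d_head = d_model // n_heads
    # Project once, then split into queries, keys and values: each (seq_len, d_model).
    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)

    # Reshape to (n_heads, seq_len, d_head) so every head attends independently.
    def to_heads(t):
        return t.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = to_heads(q), to_heads(k), to_heads(v)
    # Scaled dot-product scores: (n_heads, seq_len, seq_len).
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(d_head)
    # Lower-triangular mask: position i may only see positions <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -1e10)
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of values, then merge heads back to (seq_len, d_model).
    out = (weights @ v).transpose(1, 0, 2).reshape(seq_len, d_model)
    return out @ w_out

# Usage with small random weights.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (5, 16))
w_qkv = jax.random.normal(k2, (16, 48)) * 0.02
w_out = jax.random.normal(k3, (16, 16)) * 0.02
y = causal_self_attention(x, w_qkv, w_out, n_heads=4)
print(y.shape)  # (5, 16)
```

Because of the mask, changing the last token's input leaves the outputs at all earlier positions unchanged, which is what allows next-token training on a whole sequence at once.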
saurabhaloneai/decoder-only-transformer-in-jax
About
A decoder-only transformer in pure JAX.
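Besides attention, each GPT-2-style decoder block has a position-wise feed-forward sublayer, applied after a layer norm (GPT-2 uses pre-normalization) and wrapped in a residual connection. Here is a minimal sketch in pure JAX, assuming a 4x hidden expansion and the tanh-approximated GELU that `jax.nn.gelu` uses by default. The names `layer_norm`, `mlp_sublayer`, and the parameter-dict keys are hypothetical, not this repository's code.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize over the feature axis, then apply a learned scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta

def mlp_sublayer(x, params):
    """Pre-LN feed-forward sublayer with a residual connection (hypothetical sketch)."""
    h = layer_norm(x, params["ln_g"], params["ln_b"])
    h = jax.nn.gelu(h @ params["w_fc"] + params["b_fc"])  # expand to d_hidden
    h = h @ params["w_proj"] + params["b_proj"]            # project back to d_model
    return x + h                                           # residual add

# Usage with small random weights.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
d_model, d_hidden = 16, 64
params = {
    "ln_g": jnp.ones(d_model),
    "ln_b": jnp.zeros(d_model),
    "w_fc": jax.random.normal(k1, (d_model, d_hidden)) * 0.02,
    "b_fc": jnp.zeros(d_hidden),
    "w_proj": jax.random.normal(k2, (d_hidden, d_model)) * 0.02,
    "b_proj": jnp.zeros(d_model),
}
x = jax.random.normal(k3, (5, d_model))
y = mlp_sublayer(x, params)
print(y.shape)  # (5, 16)
```

Normalizing before the sublayer and adding the result back to `x` keeps the residual stream as an identity path: with a zero output projection the sublayer returns its input unchanged.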