A Colab notebook is coming soon. Here's how we implement GPT:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, nn
from einops import rearrange, reduce, repeat
from math import sqrt
# base class: parameters live in a nested dict, so every forward pass is a pure
# function of (param, x), which is what jax.grad and jax.jit expect
class Module(object):
    def __init__(self, param=None, name=None):
        self.param = {} if param is None else param
        self.name = name
    def add(self, module):
        # register a submodule: its parameters become a sub-dict keyed by its name
        self.param[module.name] = module.param
        setattr(self, module.name, module)
    def __call__(self, param, x, *args):
        # a named module picks its own slice out of the parent's param dict
        param = param if self.name is None else param[self.name]
        return self.forward(param, x, *args)
    def __repr__(self):
        def indent(string, spaces):
            return '\n'.join(' ' * spaces + line for line in string.split('\n'))
        def stringify(param):
            strings = []
            for name, value in param.items():
                if isinstance(value, dict):
                    strings.append(f'({name}): {getattr(self, name).__repr__()}')
                else:
                    strings.append(f'({name}): Array(shape={value.shape}, dtype={value.dtype})')
            return '\n'.join(strings)
        string = stringify(self.param)
        string = f'(\n{indent(string, 2)}\n)' if len(string) > 0 else '()'
        return f'{self.__class__.__name__}' + string
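# (sketch, not in the original post) a toy leaf module illustrating the pattern:
# the arrays sit in a plain dict and forward only reads from its param argument,
# so the whole model stays a pure function of its parameter tree
class Scale(Module):
    def __init__(self, d, name=None):
        super().__init__(param={'g': jnp.ones(d)}, name=name)
    def forward(self, param, x):
        return param['g'] * x
scale = Scale(4)
print(scale(scale.param, jnp.arange(4.)))  # [0. 1. 2. 3.]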
# define submodules
class GPTLayer(Module):
    def __init__(self, key, d, nh, name=None):
        super().__init__(name=name)
        self.d, self.nh = d, nh
        key1, key2, key3, key4, key5, key6 = random.split(key, 6)
        self.add(Linear(key1, d, 3 * d, name='wx'))  # fused q, k, v projection
        self.add(Linear(key2, d, d, name='wo'))      # attention output projection
        self.add(
            Sequential(
                Linear(key3, d, 4 * d),
                GELU(),
                Linear(key4, 4 * d, d),
                name='ffn'
            )
        )
        self.add(LayerNorm(key5, d, name='mhaln'))
        self.add(LayerNorm(key6, d, name='ffnln'))
    def forward(self, param, x, mask):
        # pre-norm attention: fused projection, split into heads, then scaled dot-product
        # with the causal mask (scores are scaled by the per-head dimension)
        q, k, v = rearrange(self.wx(param, self.mhaln(param, x)), 'l (o nh dh) -> o nh l dh', o=3, nh=self.nh)
        heads = nn.softmax(jnp.einsum('hic, hjc -> hij', q, k) / sqrt(self.d // self.nh) + mask, -1) @ v
        attn = x + self.wo(param, rearrange(heads, 'nh l dh -> l (nh dh)'))
        # the feed-forward residual wraps the attention output, not the layer input
        return attn + self.ffn(param, self.ffnln(param, attn))
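# (sketch, not in the original post) shape check for the attention pattern above:
# a length-5 input with d=8 and 2 heads gives per-head q, k, v of shape
# (heads, length, head_dim) and per-head (length, length) attention logits
qkv = jnp.ones((5, 3 * 8))
q_, k_, v_ = rearrange(qkv, 'l (o nh dh) -> o nh l dh', o=3, nh=2)
print(q_.shape)                                     # (2, 5, 4)
print(jnp.einsum('hic, hjc -> hij', q_, k_).shape)  # (2, 5, 5)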
class GPT(Module):
    def __init__(self, key, d, nh, nl, l, v):
        super().__init__()
        self.d, self.nh, self.nl, self.l, self.v = d, nh, nl, l, v
        key, subkey = random.split(key)
        self.add(Embedding(subkey, v, d, name='embedding'))
        key, subkey = random.split(key)
        self.add(Linear(subkey, d, v, name='out'))
        key, subkey = random.split(key)
        self.param['position'] = random.normal(subkey, (l, d))  # learned positional embeddings
        # causal mask: 0 on and below the diagonal, -inf above (added to the attention logits)
        mask = jnp.tril(jnp.ones((l, l))) - 1
        mask = jnp.where(mask == -1, float('-inf'), mask)
        self.param['mask'] = mask
        keys = random.split(key, nl)
        self.add(Sequential(*[GPTLayer(keys[i], d, nh) for i in range(nl)], name='layers'))
    def forward(self, param, tokens):
        l = tokens.shape[-1]
        # read positions and mask from param so forward stays a pure function of (param, tokens)
        x = self.embedding(param, tokens) + param['position'][:l]
        x = self.layers(param, x, param['mask'][:l, :l])
        return self.out(param, x)
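# (sketch, not in the original post) what the causal mask looks like for a length-4 context:
# zeros on and below the diagonal, -inf above, so softmax gives future positions zero weight
m = jnp.tril(jnp.ones((4, 4))) - 1
print(jnp.where(m == -1, float('-inf'), m))
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]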
model = GPT(random.PRNGKey(0), 64, 8, 4, 256, 1024)  # d=64, 8 heads, 4 layers, context 256, vocab 1024
print(model)
This prints
GPT(
  (embedding): Embedding(n_embedding=1024, embed_dim=64)
  (out): Linear(in_dim=64, out_dim=1024, bias=True)
  (position): Array(shape=(256, 64), dtype=float32)
  (mask): Array(shape=(256, 256), dtype=float32)
  (layers): Sequential(
    (layer0): GPTLayer(
      (wx): Linear(in_dim=64, out_dim=192, bias=True)
      (wo): Linear(in_dim=64, out_dim=64, bias=True)
      (ffn): Sequential(
        (layer0): Linear(in_dim=64, out_dim=256, bias=True)
        (layer1): GELU()
        (layer2): Linear(in_dim=256, out_dim=64, bias=True)
      )
      (mhaln): LayerNorm(dim=64)
      (ffnln): LayerNorm(dim=64)
    )
    (layer1): GPTLayer(
      (wx): Linear(in_dim=64, out_dim=192, bias=True)
      (wo): Linear(in_dim=64, out_dim=64, bias=True)
      (ffn): Sequential(
        (layer0): Linear(in_dim=64, out_dim=256, bias=True)
        (layer1): GELU()
        (layer2): Linear(in_dim=256, out_dim=64, bias=True)
      )
      (mhaln): LayerNorm(dim=64)
      (ffnln): LayerNorm(dim=64)
    )
    (layer2): GPTLayer(
      (wx): Linear(in_dim=64, out_dim=192, bias=True)
      (wo): Linear(in_dim=64, out_dim=64, bias=True)
      (ffn): Sequential(
        (layer0): Linear(in_dim=64, out_dim=256, bias=True)
        (layer1): GELU()
        (layer2): Linear(in_dim=256, out_dim=64, bias=True)
      )
      (mhaln): LayerNorm(dim=64)
      (ffnln): LayerNorm(dim=64)
    )
    (layer3): GPTLayer(
      (wx): Linear(in_dim=64, out_dim=192, bias=True)
      (wo): Linear(in_dim=64, out_dim=64, bias=True)
      (ffn): Sequential(
        (layer0): Linear(in_dim=64, out_dim=256, bias=True)
        (layer1): GELU()
        (layer2): Linear(in_dim=256, out_dim=64, bias=True)
      )
      (mhaln): LayerNorm(dim=64)
      (ffnln): LayerNorm(dim=64)
    )
  )
)
just like PyTorch's nn.Module summaries.
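As a quick sanity check, here is a minimal usage sketch (not from the original post): because forward is a pure function of the parameter dict, jit and value_and_grad compose with the model directly. The loss_fn name and the next-token cross-entropy are illustrative assumptions, not part of the implementation above.
# sketch: forward pass and gradients of a next-token cross-entropy loss
tokens = jnp.arange(10)                 # 10 dummy token ids (vocab size is 1024 above)
logits = model(model.param, tokens)
print(logits.shape)                     # (10, 1024)
def loss_fn(param, tokens):
    logits = model(param, tokens[:-1])
    logprobs = nn.log_softmax(logits, -1)
    return -jnp.mean(jnp.take_along_axis(logprobs, tokens[1:, None], -1))
loss, grads = jit(jax.value_and_grad(loss_fn))(model.param, tokens)
The returned grads mirrors the nested structure of model.param, so an optimizer can update every leaf with a simple tree map.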