A Rudimentary DL Library Using JAX

Colab notebook coming soon. Here's how we implement GPT with this library:

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, nn
from einops import rearrange, reduce, repeat
from math import sqrt


# base class: every module keeps its parameters in a nested dict (self.param) that
# mirrors the module tree, and forward() receives that dict explicitly, so the whole
# model stays a pure function of its parameters and composes with jax.grad / jax.jit
class Module(object):
  def __init__(self, param=None, name=None):
    self.param = {} if param is None else param
    self.name = name

  def add(self, module):
    self.param[module.name] = module.param
    setattr(self, module.name, module)

  def __call__(self, param, x, *args):
    param = param if self.name is None else param[self.name]
    return self.forward(param, x, *args)

  def __repr__(self):
    def indent(string, spaces):
      return '\n'.join(' ' * spaces + line for line in string.split('\n'))

    def stringify(param):
      strings = []
      for name, value in param.items():
        if isinstance(value, dict):
          strings.append(f'({name}): {getattr(self, name).__repr__()}')
        else:
          strings.append(f'({name}): Array(shape={value.shape}, dtype={value.dtype})')
      return '\n'.join(strings)

    string = stringify(self.param)
    string = f'(\n{indent(string, 2)}\n)' if len(string) > 0 else '()'
    return f'{self.__class__.__name__}' + string


# define submodules

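# The leaf modules used below (Linear, GELU, LayerNorm, Embedding, Sequential) belong to
# the library itself rather than to the GPT code. Minimal sketches consistent with the
# printed repr further down might look like this; the exact initializers are an
# assumption, and the real library's versions (e.g. in the Colab) may differ.


class Linear(Module):
  def __init__(self, key, in_dim, out_dim, bias=True, name=None):
    super().__init__(name=name)
    self.in_dim, self.out_dim, self.bias = in_dim, out_dim, bias
    self.param['w'] = random.normal(key, (in_dim, out_dim)) / sqrt(in_dim)
    if bias:
      self.param['b'] = jnp.zeros(out_dim)

  def forward(self, param, x):
    y = x @ param['w']
    return y + param['b'] if self.bias else y

  def __repr__(self):
    return f'Linear(in_dim={self.in_dim}, out_dim={self.out_dim}, bias={self.bias})'


class GELU(Module):
  # no parameters; just applies the activation
  def forward(self, param, x):
    return nn.gelu(x)


class LayerNorm(Module):
  def __init__(self, key, dim, eps=1e-5, name=None):
    # key is unused; kept only for a uniform constructor signature
    super().__init__(name=name)
    self.dim, self.eps = dim, eps
    self.param['gamma'] = jnp.ones(dim)
    self.param['beta'] = jnp.zeros(dim)

  def forward(self, param, x):
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return param['gamma'] * (x - mean) / jnp.sqrt(var + self.eps) + param['beta']

  def __repr__(self):
    return f'LayerNorm(dim={self.dim})'


class Embedding(Module):
  def __init__(self, key, n_embedding, embed_dim, name=None):
    super().__init__(name=name)
    self.n_embedding, self.embed_dim = n_embedding, embed_dim
    self.param['embedding'] = random.normal(key, (n_embedding, embed_dim))

  def forward(self, param, tokens):
    return param['embedding'][tokens]

  def __repr__(self):
    return f'Embedding(n_embedding={self.n_embedding}, embed_dim={self.embed_dim})'


class Sequential(Module):
  def __init__(self, *modules, name=None):
    super().__init__(name=name)
    self.n = len(modules)
    for i, module in enumerate(modules):
      module.name = f'layer{i}'  # children are registered as layer0, layer1, ...
      self.add(module)

  def forward(self, param, x, *args):
    for i in range(self.n):
      x = getattr(self, f'layer{i}')(param, x, *args)
    return x
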

class GPTLayer(Module):
  def __init__(self, key, d, nh, name=None):
    super().__init__(name=name)
    self.d, self.nh = d, nh
    key1, key2, key3, key4, key5, key6 = random.split(key, 6)
    self.add(Linear(key1, d, 3 * d, name='wx'))  # fused q, k, v projection
    self.add(Linear(key2, d, d, name='wo'))      # attention output projection
    self.add(
      Sequential(
        Linear(key3, d, 4 * d),
        GELU(),
        Linear(key4, 4 * d, d),
        name='ffn'
      )
    )
    self.add(LayerNorm(key5, d, name='mhaln'))
    self.add(LayerNorm(key6, d, name='ffnln'))

  def forward(self, param, x, mask):
    # pre-LN attention: fused projection to per-head q, k, v
    q, k, v = rearrange(self.wx(param, self.mhaln(param, x)), 'l (o nh dh) -> o nh l dh', o=3, nh=self.nh)
    # scaled dot-product attention with the causal mask; scale by the per-head dimension
    heads = nn.softmax(jnp.einsum('hic, hjc -> hij', q, k) / sqrt(self.d // self.nh) + mask, -1) @ v
    attn = x + self.wo(param, rearrange(heads, 'nh l dh -> l (nh dh)'))
    # the second residual connection wraps the feed-forward block around attn, not x
    return attn + self.ffn(param, self.ffnln(param, attn))


class GPT(Module):
  def __init__(self, key, d, nh, nl, l, v):
    super().__init__()
    self.d, self.nh, self.nl, self.l, self.v = d, nh, nl, l, v

    key, subkey = random.split(key)
    self.add(Embedding(subkey, v, d, name='embedding'))

    key, subkey = random.split(key)
    self.add(Linear(subkey, d, v, name='out'))

    # learned positional embeddings, added to the token embeddings in forward()
    key, subkey = random.split(key)
    self.param['position'] = random.normal(subkey, (l, d))

    # causal mask: 0 on and below the diagonal, -inf above, added to the attention logits
    mask = jnp.where(jnp.tril(jnp.ones((l, l))) == 1, 0.0, float('-inf'))
    self.param['mask'] = mask

    keys = random.split(key, nl)
    self.add(Sequential(*[GPTLayer(keys[i], d, nh) for i in range(nl)], name='layers'))

  def forward(self, param, tokens):
    l = tokens.shape[-1]
    # read position and mask from the param dict that was passed in, so the positional
    # embeddings are trained alongside the other parameters
    x = self.embedding(param, tokens) + param['position'][:l]
    x = self.layers(param, x, param['mask'][:l, :l])
    return self.out(param, x)


# d=64, nh=8 heads, nl=4 layers, context length l=256, vocabulary size v=1024
model = GPT(random.PRNGKey(0), 64, 8, 4, 256, 1024)
print(model)

This prints

GPT(
  (embedding): Embedding(n_embedding=1024, embed_dim=64)
  (out): Linear(in_dim=64, out_dim=1024, bias=True)
  (position): Array(shape=(256, 64), dtype=float32)
  (mask): Array(shape=(256, 256), dtype=float32)
  (layers): Sequential(
    (layer0): GPTLayer(
      (wx): Linear(in_dim=64, out_dim=192, bias=True)
      (wo): Linear(in_dim=64, out_dim=64, bias=True)
      (ffn): Sequential(
        (layer0): Linear(in_dim=64, out_dim=256, bias=True)
        (layer1): GELU()
        (layer2): Linear(in_dim=256, out_dim=64, bias=True)
      )
      (mhaln): LayerNorm(dim=64)
      (ffnln): LayerNorm(dim=64)
    )
    (layer1): GPTLayer(
      (wx): Linear(in_dim=64, out_dim=192, bias=True)
      (wo): Linear(in_dim=64, out_dim=64, bias=True)
      (ffn): Sequential(
        (layer0): Linear(in_dim=64, out_dim=256, bias=True)
        (layer1): GELU()
        (layer2): Linear(in_dim=256, out_dim=64, bias=True)
      )
      (mhaln): LayerNorm(dim=64)
      (ffnln): LayerNorm(dim=64)
    )
    (layer2): GPTLayer(
      (wx): Linear(in_dim=64, out_dim=192, bias=True)
      (wo): Linear(in_dim=64, out_dim=64, bias=True)
      (ffn): Sequential(
        (layer0): Linear(in_dim=64, out_dim=256, bias=True)
        (layer1): GELU()
        (layer2): Linear(in_dim=256, out_dim=64, bias=True)
      )
      (mhaln): LayerNorm(dim=64)
      (ffnln): LayerNorm(dim=64)
    )
    (layer3): GPTLayer(
      (wx): Linear(in_dim=64, out_dim=192, bias=True)
      (wo): Linear(in_dim=64, out_dim=64, bias=True)
      (ffn): Sequential(
        (layer0): Linear(in_dim=64, out_dim=256, bias=True)
        (layer1): GELU()
        (layer2): Linear(in_dim=256, out_dim=64, bias=True)
      )
      (mhaln): LayerNorm(dim=64)
      (ffnln): LayerNorm(dim=64)
    )
  )
)

just like PyTorch's module summaries.
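Because the parameters live in a plain nested dict and every forward pass takes that dict explicitly, the model composes directly with jax.grad and jax.jit. As a rough sketch (not from the original code: loss_fn, train_step, and the toy token sequence below are made up for illustration), a next-token training step could look like this:

def loss_fn(param, tokens, targets):
  logits = model(param, tokens)                               # (l, v)
  logp = nn.log_softmax(logits, -1)
  return -logp[jnp.arange(targets.shape[0]), targets].mean()  # cross-entropy

@jit
def train_step(param, tokens, targets, lr=3e-4):
  loss, grads = jax.value_and_grad(loss_fn)(param, tokens, targets)
  # plain SGD over the nested param dict
  new_param = jax.tree_util.tree_map(lambda p, g: p - lr * g, param, grads)
  new_param['mask'] = param['mask']  # the causal mask sits in the dict but should stay fixed
  return new_param, loss

tokens = jnp.arange(17) % 1024  # toy sequence of token ids
param, loss = train_step(model.param, tokens[:-1], tokens[1:])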