Automatic differentiation



Finite difference method

The simplest way to compute the derivative of a function comes from the definition of a partial derivative

$$\frac{\partial f}{\partial x_i} = \lim_{h \rightarrow 0}\frac{f(x_1,\ldots,x_i + h,\ldots,x_n) - f(x_1,\ldots,x_n)}{h}.$$

For some function

# define some function
def f(a, b, c):
  return 2 * a * b * c * (a ** 2 + b ** 2 + c ** 2)

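to estimate the derivative w.r.t. a at some point (a, b, c)

# evaluation point (values picked just for illustration)
a, b, c = 1.0, 2.0, 3.0
# smaller h => smaller truncation error, but floating-point
# round-off takes over once h gets too small
h = 0.0001
# definition of derivative, with a small h instead of the limit
df_over_da = (f(a + h, b, c) - f(a, b, c)) / h

print(df_over_da)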

This method is slow: computing the full gradient takes $O(P)$ function evaluations, where $P$ is the number of parameters, and the result is only an approximation. What if the function has billions of parameters (neural nets)? What if you need to compute the gradient at billions of points?
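
To make the $O(P)$ cost concrete, here's a minimal sketch of a full finite-difference gradient. The helper name fd_gradient, the step size, and the evaluation point are my own picks for illustration. Notice the loop: one extra function evaluation per parameter.

```python
def f(a, b, c):
    return 2 * a * b * c * (a ** 2 + b ** 2 + c ** 2)


def fd_gradient(func, point, h=1e-6):
    """Forward-difference estimate of the gradient of func at point."""
    base = func(*point)  # evaluated once, reused for every partial derivative
    grad = []
    for i in range(len(point)):
        shifted = list(point)
        shifted[i] += h  # nudge only the i-th parameter
        grad.append((func(*shifted) - base) / h)
    return grad


point = (1.0, 2.0, 3.0)  # hypothetical evaluation point
approx = fd_gradient(f, point)

# analytic gradient (product rule) to compare against
a, b, c = point
s = a ** 2 + b ** 2 + c ** 2
exact = [
    2 * b * c * s + 4 * a ** 2 * b * c,
    2 * a * c * s + 4 * a * b ** 2 * c,
    2 * a * b * s + 4 * a * b * c ** 2,
]

print("finite difference:", approx)
print("analytic:         ", exact)
```

For three parameters that's four calls to f. For a billion parameters it's a billion and one calls per gradient.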


Reverse-mode automatic differentiation (RMAD)

The basic idea of automatic differentiation is to break down a complex function into a series of primitive operations and use the chain rule to compute the derivatives recursively. For example, consider the function

$$f(x_1, x_2) = \frac{\log(x_1x_2(x_1 + x_2))}{(x_1 - x_2)^2}.$$

Our goal is to find the derivative of $f$ w.r.t. $x_1$. In Python code, we'll store the output of the function for inputs x1 and x2 in some object y. We then call some method of y that'll compute the derivative of y w.r.t. its parameters (here, x1 and x2). How do we do this? We start by breaking the above down into simple terms like so

\begin{eqnarray} t_1 &=& x_1 + x_2 \\ t_2 &=& x_1x_2 \\ t_3 &=& t_1t_2 \\ t_4 &=& \log(t_3) \\ t_5 &=& x_1 - x_2 \\ t_6 &=& t_5^2 \\ f(x_1, x_2) = t_7 &=& \frac{t_4}{t_6}. \end{eqnarray}

A "term" should be of one of the following forms

\begin{eqnarray} t_i &=& u(x_j) \\ t_i &=& x_j \circ x_k \\ t_i &=& u(t_j) \\ t_i &=& t_j \circ t_k \\ t_i &=& t_j \circ x_k \end{eqnarray}

where $\circ$ is some binary operation (e.g., adding), $u$ is some unary operation (e.g., log), and $i$, $j$, $k$ may or may not be equal to one another. We start from $t_7$ and work our way to lesser terms using the chain rule

\begin{eqnarray} \frac{\partial t_7}{\partial t_7} &=& 1 \\ \frac{\partial t_7}{\partial t_6} &=& \frac{-t_4}{t_6^2} \\ \frac{\partial t_7}{\partial t_4} &=& \frac{1}{t_6} \\ \frac{\partial t_7}{\partial t_5} &=& 2t_5\frac{\partial t_7}{\partial t_6} \\ \frac{\partial t_5}{\partial x_1} &=& 1 \tag{1} \\ \frac{\partial t_7}{\partial t_3} &=& \frac{1}{t_3}\frac{\partial t_7}{\partial t_4} \\ \frac{\partial t_7}{\partial t_2} &=& t_1\frac{\partial t_7}{\partial t_3} \\ \frac{\partial t_7}{\partial t_1} &=& t_2\frac{\partial t_7}{\partial t_3} \\ \frac{\partial t_2}{\partial x_1} &=& x_2 \tag{2} \\ \frac{\partial t_1}{\partial x_1} &=& 1. \tag{3} \end{eqnarray}

Consider $(1)$, $(2)$, and $(3)$. The neat thing is that

$$\frac{\partial f}{\partial x_1} = \frac{\partial t_7}{\partial x_1} = \frac{\partial t_7}{\partial t_5}\frac{\partial t_5}{\partial x_1} + \frac{\partial t_7}{\partial t_2}\frac{\partial t_2}{\partial x_1} + \frac{\partial t_7}{\partial t_1}\frac{\partial t_1}{\partial x_1}.$$

We keep substituting the values of the partial derivatives of $t_7$ w.r.t. the intermediate terms until the final expression depends only on $t_1,\ldots,t_7$ and $x_1, x_2$

\begin{eqnarray} \frac{\partial f}{\partial x_1} &=& \frac{\partial t_7}{\partial t_5}\frac{\partial t_5}{\partial x_1} + \frac{\partial t_7}{\partial t_2}\frac{\partial t_2}{\partial x_1} + \frac{\partial t_7}{\partial t_1}\frac{\partial t_1}{\partial x_1} \\ &=& \frac{\partial t_7}{\partial t_5} + x_2\frac{\partial t_7}{\partial t_2} + \frac{\partial t_7}{\partial t_1} \\ &=& 2t_5\frac{\partial t_7}{\partial t_6} + x_2t_1\frac{\partial t_7}{\partial t_3} + t_2\frac{\partial t_7}{\partial t_3} \\ &=& \frac{-2t_5t_4}{t_6^2} + \frac{x_2t_1}{t_3}\frac{\partial t_7}{\partial t_4} + \frac{t_2}{t_3}\frac{\partial t_7}{\partial t_4} \\ &=& \frac{-2t_5t_4}{t_6^2} + \frac{x_2t_1}{t_3t_6} + \frac{t_2}{t_3t_6}. \end{eqnarray}
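
As a sanity check, here's the same derivation done with plain floats. The point $x_1 = 2$, $x_2 = 10$ and the names d7_k (standing for $\partial t_7 / \partial t_k$) are my own choices for this sketch. It runs the forward pass, then the reverse pass, and compares the result with the closed-form last line above and with a finite-difference estimate.

```python
import math

x1, x2 = 2.0, 10.0

# forward pass: evaluate every term
t1 = x1 + x2
t2 = x1 * x2
t3 = t1 * t2
t4 = math.log(t3)
t5 = x1 - x2
t6 = t5 ** 2
t7 = t4 / t6

# reverse pass: d7_k is the partial derivative of t7 w.r.t. t_k
d7_6 = -t4 / t6 ** 2
d7_4 = 1 / t6
d7_5 = 2 * t5 * d7_6
d7_3 = d7_4 / t3
d7_2 = t1 * d7_3
d7_1 = t2 * d7_3

# sum over the three terms that depend directly on x1
df_dx1 = d7_5 * 1 + d7_2 * x2 + d7_1 * 1

# last line of the derivation, written out directly
closed_form = -2 * t5 * t4 / t6 ** 2 + x2 * t1 / (t3 * t6) + t2 / (t3 * t6)


def f(u, v):
    return math.log(u * v * (u + v)) / (u - v) ** 2


h = 1e-6
finite_diff = (f(x1 + h, x2) - f(x1, x2)) / h

print(df_dx1, closed_form, finite_diff)
```

All three numbers should agree to several decimal places.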

Let's state this more generally. Suppose you break a function $f$ with parameters $x_1,\ldots,x_n$ into many intermediate terms $T = \{t_1,\ldots,t_m\}$ (as we did above). Let $\tau \subset T$ be the set of all $t \in T$ that depend directly on $t_i$, i.e., that take $t_i$ as an input. Then

$$\frac{\partial f}{\partial t_i} = \sum_{t \in \tau}\frac{\partial f}{\partial t}\frac{\partial t}{\partial t_i}. \tag{4}$$

It follows that

$$\frac{\partial f}{\partial x_i} = \sum_{t \in \tau}\frac{\partial f}{\partial t}\frac{\partial t}{\partial x_i} \tag{5}$$

where $\tau \subset T$ is the set of all $t \in T$ dependent on $x_i$. This is simply the multivariable chain rule. The equations above are important because they're how we "backpropagate" the derivative from the final term (i.e., $f(\ldots)$) to its parameters $x_1,\ldots,x_n$. Let's implement this in Python.
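
Before the object-oriented version, here's a rough sketch of eq. (4) and (5) as a plain loop over the example from earlier. The dictionary layout and the names local and adjoint are assumptions of mine, not part of any library. Each term stores the local partial derivative w.r.t. each of its inputs, and walking the terms from last to first adds every contribution into the right slot.

```python
import math

x1, x2 = 2.0, 10.0  # hypothetical evaluation point

# forward pass
t1 = x1 + x2
t2 = x1 * x2
t3 = t1 * t2
t4 = math.log(t3)
t5 = x1 - x2
t6 = t5 ** 2
t7 = t4 / t6

# local[t][u] is the partial derivative of term t w.r.t. its direct input u
local = {
    "t1": {"x1": 1.0, "x2": 1.0},
    "t2": {"x1": x2, "x2": x1},
    "t3": {"t1": t2, "t2": t1},
    "t4": {"t3": 1 / t3},
    "t5": {"x1": 1.0, "x2": -1.0},
    "t6": {"t5": 2 * t5},
    "t7": {"t4": 1 / t6, "t6": -t4 / t6 ** 2},
}

# adjoint[u] accumulates the partial derivative of f w.r.t. u
adjoint = {name: 0.0 for name in ["x1", "x2", *local]}
adjoint["t7"] = 1.0  # df/dt7 = 1

# reverse order guarantees adjoint[term] is complete before it is used
for term in ["t7", "t6", "t5", "t4", "t3", "t2", "t1"]:
    for inp, partial in local[term].items():
        adjoint[inp] += adjoint[term] * partial  # one summand of eq. (4)/(5)

print(f"df/dx1 = {adjoint['x1']}, df/dx2 = {adjoint['x2']}")
```

The `+=` is the sum in eq. (4) and (5): x1 feeds t1, t2, and t5, so it receives three contributions.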


Implementation

We create a class in Python with the following variables.

class Value:
  def __init__(self, val, reqgrad=False, grad_fn=None):
    self.val, self.reqgrad, self.grad_fn = val, reqgrad, grad_fn
    self.grad = 0

  # dunder methods here; will be explained below

  # backprop
  def back(self, grad=1):
    # reason for addition here is due to eq (4) and (5) 
    self.grad += grad
    if self.grad_fn: self.grad_fn(grad)  # pass the incoming gradient on
  

a = Value(69, reqgrad=True)
b = Value(420, reqgrad=True)
c = a * b
c.back()  # should compute dc/da and dc/db
print(f"dc/da = {a.grad}, dc/db = {b.grad}")

The above code will be explained shortly, but we need to revise a few concepts first. In Python, everything is an object, integers included. When you add two integers (or apply some other operator), you're calling the __add__ dunder (short for double underscore) method of the integer object.

x = 5 + 2

# the above is the same as

x = (5).__add__(2)

# an integer object with value 5 is created
# __add__ method of the object is called w/ argument 2
# returns new integer object with value 5 + 2, i.e., 7
# try it in your terminal!

The cool thing is we can define the dunder methods of a class. For example

class Foo:
  def __init__(self, val):
    self.val = val

  # gets called whenever Foo instance multiplied by something else
  # something else gets passed as argument "other"
  def __mul__(self, other):
    # multiplying two Foo instances x and y returns
    # a new Foo instance whose val is x.val * y.val
    return Foo(self.val * other.val)


x = Foo(2)
y = Foo(4)
z = x.__mul__(y)  # allowed 'cause we've defined __mul__ for Foo
print(z, z.val)
z = x * y
print(z, z.val)  # same thing as above
z = x - y  # raises TypeError since __sub__ is not defined for Foo

Understanding how a function call stack works is useful here.

a = Value(2, reqgrad=True)
b = Value(10, reqgrad=True)
c = Value(3, reqgrad=True)
y = (a * b) * c
# a * b is equiv to a.__mul__(b) which returns
# an intermediate Value instance t1
# t1 * c is equiv to t1.__mul__(c) which returns
# the resulting Value, which is stored in y

When two Value objects are multiplied, the dunder method __mul__ returns a new object, i.e., a new intermediate term $t_i$. Suppose $t_i = t_jt_k$ and we know $\frac{\partial f}{\partial t_i}$. Then

\begin{eqnarray} \frac{\partial f}{\partial t_j} &=& t_k\frac{\partial f}{\partial t_i} \\ \frac{\partial f}{\partial t_k} &=& t_j\frac{\partial f}{\partial t_i} \end{eqnarray}

With this, we can define the backprop logic in the __mul__ method for Value.

class Value:
  def __init__(self, val, reqgrad=False, grad_fn=None):
    self.val, self.reqgrad, self.grad_fn = val, reqgrad, grad_fn
    self.grad = 0

  def __mul__(self, other):
    # define a closure
    # self * other returns a new Value object (call it x)
    # grad is the derivative of the final result w.r.t. x
    # while computing x we don't know grad yet,
    # so this closure is stored as x.grad_fn
    # when x.back is called (look @ the back method below),
    # x.grad_fn, i.e., the closure below, is called
    # and it calls the back method of x's arguments
    def grad_fn(grad):
      if self.reqgrad: self.back(grad * other.val)
      if other.reqgrad: other.back(grad * self.val)
    return Value(self.val * other.val, self.reqgrad or other.reqgrad, grad_fn)

  def back(self, grad=1):
    # say, x = a * b and grad is one contribution to dy / dx
    # x.grad adds up every such contribution, see eq (4) and (5)
    # x.grad_fn is a closure holding references to x's arguments a and b
    # note what happens in grad_fn
    # grad is passed into x.grad_fn
    # the matching contribution to dy / da is grad * b
    # and to dy / db it's grad * a
    # a.back(grad * b) is called, and a.grad becomes a.grad + grad * b
    # vice versa for b
    # only the new contribution grad is forwarded (not x.grad),
    # otherwise a term used in several places gets counted twice
    self.grad += grad  # from eq (4) and (5)
    if self.grad_fn: self.grad_fn(grad)
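
Here's a self-contained sketch you can paste into a file and run. It only supports multiplication, and the test values are my own. In this sketch back forwards just the incoming contribution grad to grad_fn, so a term that shows up more than once (like t in t * t below) isn't counted twice.

```python
class Value:
    def __init__(self, val, reqgrad=False, grad_fn=None):
        self.val, self.reqgrad, self.grad_fn = val, reqgrad, grad_fn
        self.grad = 0

    def __mul__(self, other):
        # closure remembers both operands of this multiplication
        def grad_fn(grad):
            if self.reqgrad:
                self.back(grad * other.val)
            if other.reqgrad:
                other.back(grad * self.val)
        return Value(self.val * other.val, self.reqgrad or other.reqgrad, grad_fn)

    def back(self, grad=1):
        self.grad += grad  # accumulate, as in eq. (4) and (5)
        if self.grad_fn:
            self.grad_fn(grad)  # forward only the new contribution


# y = (a * b) * c, so dy/da = b * c, dy/db = a * c, dy/dc = a * b
a = Value(2, reqgrad=True)
b = Value(10, reqgrad=True)
c = Value(3, reqgrad=True)
y = (a * b) * c
y.back()
print(a.grad, b.grad, c.grad)

# z = (p * q) * (p * q), where the intermediate term t is used twice
p = Value(2, reqgrad=True)
q = Value(10, reqgrad=True)
t = p * q
z = t * t
z.back()
print(p.grad, q.grad)  # analytically 2 * p * q**2 and 2 * p**2 * q
```

One caveat of this design: every path through the graph is walked separately, which is fine for small expressions but can get expensive when terms are shared heavily.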

For a unary operation like log

import math

class Value:
  # ...
  
  def log(self):
    def grad_fn(grad):
      # x = a.log()
      # then dy / da = dy / dx * dx / da = grad * 1 / a 
      if self.reqgrad: self.back(grad * 1 / self.val)
    return Value(math.log(self.val), self.reqgrad, grad_fn)
    
  # ...

I'll leave the rest of the operations (+, -, /, **, etc.) as an exercise. Refer to my implementation on GitHub if you want. The way it'll work for this function is

def f(x1: Value, x2: Value) -> Value:
  return (x1 * x2 * (x1 + x2)).log() / (x1 - x2) ** 2
  

x1, x2 = Value(2, reqgrad=True), Value(10, reqgrad=True)
y = f(x1, x2)
y.back()  # backprop to differentiate

print(f"df / dx1 = {x1.grad}, df / dx2 = {x2.grad}")
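
If you want something runnable right away, here's one possible sketch of the missing operations. It isn't the GitHub implementation, just a minimal version I'd consider reasonable. The _binary helper is a name I made up, and __pow__ assumes the exponent is a plain Python number rather than a Value.

```python
import math


class Value:
    def __init__(self, val, reqgrad=False, grad_fn=None):
        self.val, self.reqgrad, self.grad_fn = val, reqgrad, grad_fn
        self.grad = 0

    def back(self, grad=1):
        self.grad += grad
        if self.grad_fn:
            self.grad_fn(grad)

    def _binary(self, other, val, d_self, d_other):
        # hypothetical helper: d_self and d_other are the local partial
        # derivatives of the result w.r.t. self and other
        def grad_fn(grad):
            if self.reqgrad:
                self.back(grad * d_self)
            if other.reqgrad:
                other.back(grad * d_other)
        return Value(val, self.reqgrad or other.reqgrad, grad_fn)

    def __add__(self, other):
        return self._binary(other, self.val + other.val, 1, 1)

    def __sub__(self, other):
        return self._binary(other, self.val - other.val, 1, -1)

    def __mul__(self, other):
        return self._binary(other, self.val * other.val, other.val, self.val)

    def __truediv__(self, other):
        return self._binary(other, self.val / other.val,
                            1 / other.val, -self.val / other.val ** 2)

    def __pow__(self, exponent):
        # assumes exponent is a constant number, not a Value
        def grad_fn(grad):
            if self.reqgrad:
                self.back(grad * exponent * self.val ** (exponent - 1))
        return Value(self.val ** exponent, self.reqgrad, grad_fn)

    def log(self):
        def grad_fn(grad):
            if self.reqgrad:
                self.back(grad / self.val)
        return Value(math.log(self.val), self.reqgrad, grad_fn)


def f(x1, x2):
    return (x1 * x2 * (x1 + x2)).log() / (x1 - x2) ** 2


x1, x2 = Value(2.0, reqgrad=True), Value(10.0, reqgrad=True)
y = f(x1, x2)
y.back()
print(f"df / dx1 = {x1.grad}, df / dx2 = {x2.grad}")
```

A quick way to check any implementation like this is to compare x1.grad and x2.grad against the finite-difference estimate from the first section.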

More resources on RMAD

Karpathy's micrograd
UofT lecture notes
CMU lecture notes (more technical)
YT: Understanding Automatic Differentiation via Computation Graphs
What is Automatic Differentiation?