The simplest way to compute the derivative of a function comes from the definition of a partial derivative.
For some function
# define some function
def f(a, b, c):
return 2 * a * b * c * (a ** 2 + b ** 2 + c ** 2)
to compute the derivative w.r.t. $a$ at some point $(a, b, c)$, pick a small $h$ and approximate $\frac{\partial f}{\partial a} \approx \frac{f(a + h,\, b,\, c) - f(a,\, b,\, c)}{h}$
# smaller h => better approximation (until floating-point round-off takes over)
h = 0.0001
# finite-difference approximation of the derivative w.r.t. a
df_over_da = (f(a + h, b, c) - f(a, b, c)) / h
print(df_over_da)
This method is very slow because you need to evaluate the function $O(P)$ times to compute the full gradient (where $P$ is the number of parameters). What if the function has billions of parameters (neural nets)? What if you need to compute the gradient at billions of points?
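To make the cost concrete, here's a rough sketch of a finite-difference gradient helper (call it numerical_gradient): it needs one extra evaluation of f per parameter, hence the $O(P)$ calls to f.

def numerical_gradient(f, params, h=0.0001):
    # one extra call to f per parameter => O(P) evaluations of f
    base = f(*params)
    grads = []
    for i in range(len(params)):
        bumped = list(params)
        bumped[i] += h
        grads.append((f(*bumped) - base) / h)
    return grads

# df/da, df/db, df/dc at the point (1, 2, 3), using f from above
print(numerical_gradient(f, [1.0, 2.0, 3.0]))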
The basic idea of automatic differentiation is to break down a complex function into a series of primitive operations and use the chain rule to compute the derivatives recursively. For example, consider the function

$$f(x_1, x_2) = \frac{\log\big(x_1 x_2 (x_1 + x_2)\big)}{(x_1 - x_2)^2}$$
Our goal is to find the derivative of $f$ w.r.t. $x_1$. In Python code, we'll store the output of the function for inputs x1 and x2 in some object y. We then call some method of y that computes the derivative of y w.r.t. its parameters (here, x1 and x2). How do we do this?
We start by breaking the above down into simple intermediate terms.
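Here's one such decomposition (the exact numbering is just a choice; the point is that every term is a single primitive operation):

$$
\begin{aligned}
&t_1 = x_1 x_2, \qquad t_2 = x_1 + x_2, \qquad t_3 = t_1 t_2, \qquad t_4 = \log(t_3), \\
&t_5 = x_1 - x_2, \qquad t_6 = t_5^2, \qquad t_7 = \frac{t_4}{t_6} = f(x_1, x_2).
\end{aligned}
$$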
A "term" should be of one of the following forms:

$$t_i = t_j \circ t_k \qquad \text{or} \qquad t_i = u(t_j),$$

where $\circ$ is some binary operation (e.g., addition), $u$ is some unary operation (e.g., log), the operands may also be the inputs $x_1, x_2$, and $i$, $j$, $k$ may or may not be equal to one another. We start from $t_7$ and work our way back to earlier terms using the chain rule. The neat thing is that the partial derivative of $t_7$ w.r.t. each earlier term can be written using partial derivatives we've already computed. We keep substituting the values of the partial derivatives of $t_7$ w.r.t. the intermediate terms until every expression depends only on things we know: the values of $t_1,\ldots,t_7$ and of $x_1, x_2$.
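For instance, with the decomposition above,

$$\frac{\partial t_7}{\partial t_6} = -\frac{t_4}{t_6^2}, \qquad \frac{\partial t_7}{\partial t_4} = \frac{1}{t_6}, \qquad \frac{\partial t_7}{\partial t_3} = \frac{\partial t_7}{\partial t_4} \cdot \frac{\partial t_4}{\partial t_3} = \frac{1}{t_6} \cdot \frac{1}{t_3},$$

and the last expression simply reuses $\frac{\partial t_7}{\partial t_4}$, which we had already computed.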
Let's state this more generally. Suppose you break a function $f$ with parameters $x_1,\ldots,x_n$ into intermediate terms $T = \{t_1,\ldots,t_m\}$ (as we did above). Let $\tau \subset T$ be the set of all $t \in T$ that directly depend on $t_i$ (i.e., that take $t_i$ as an operand). Then

$$\frac{\partial f}{\partial t_i} = \sum_{t \in \tau} \frac{\partial f}{\partial t} \frac{\partial t}{\partial t_i}. \tag{4}$$

It follows that

$$\frac{\partial f}{\partial x_i} = \sum_{t \in \tau} \frac{\partial f}{\partial t} \frac{\partial t}{\partial x_i}, \tag{5}$$

where $\tau \subset T$ is now the set of all $t \in T$ that directly depend on $x_i$. This is simply the multivariable chain rule. The equations above are important because they're how we "backpropagate" the derivative from the final term (i.e., $f(\ldots)$) to its parameters $x_1,\ldots,x_n$.
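For example, in the decomposition above $x_1$ appears directly in $t_1 = x_1 x_2$, $t_2 = x_1 + x_2$, and $t_5 = x_1 - x_2$, so $(5)$ gives

$$\frac{\partial f}{\partial x_1} = \frac{\partial f}{\partial t_1}\,x_2 + \frac{\partial f}{\partial t_2}\cdot 1 + \frac{\partial f}{\partial t_5}\cdot 1.$$

Each contribution arrives separately during backprop, which is why the implementation below accumulates gradients with +=. Let's implement this in Python.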
We create a class in Python with the following variables.
class Value:
    def __init__(self, val, reqgrad=False, grad_fn=None):
        self.val, self.reqgrad, self.grad_fn = val, reqgrad, grad_fn
        self.grad = 0
    # dunder methods here; will be explained below
    # backprop
    def back(self, grad=1):
        # we add (rather than assign) because of eqs (4) and (5):
        # contributions from different terms accumulate
        self.grad += grad
        if self.grad_fn: self.grad_fn(self.grad)

a = Value(69, reqgrad=True)
b = Value(420, reqgrad=True)
c = a * b
c.back()  # should compute dc/da and dc/db
print(f"dc/da = {a.grad}, dc/db = {b.grad}")
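# once __mul__ is implemented below, this prints: dc/da = 420, dc/db = 69 (i.e., dc/da = b.val, dc/db = a.val)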
The above code will be explained shortly. But we need to revise a few concepts first. In Python, integers are objects. Everything is an object. When you add two integers (or apply some other operation), you're calling the __add__ dunder (short for double underscore) method of the integer object.
x = 5 + 2
# the above is the same as
x = (5).__add__(2)
# an integer object with value 5 is created
# __add__ method of the object is called w/ argument 2
# returns new integer object with value 5 + 2, i.e., 7
# try it in your terminal!
The cool thing is we can define the dunder methods of a class. For example
class Foo:
    def __init__(self, val):
        self.val = val
    # gets called whenever a Foo instance is multiplied by something else
    # that something else gets passed as the argument "other"
    def __mul__(self, other):
        # multiplying two Foo instances x and y returns
        # a new Foo instance whose val is x.val * y.val
        return Foo(self.val * other.val)

x = Foo(2)
y = Foo(4)
z = x.__mul__(y)  # allowed 'cause we've defined __mul__ for Foo
print(z, z.val)
z = x * y
print(z, z.val)  # same thing as above
z = x - y  # not allowed since __sub__ is not defined for Foo; raises a TypeError
Understanding how a function call stack works is useful here.
a = Value(2, reqgrad=True)
b = Value(10, reqgrad=True)
c = Value(3, reqgrad=True)
y = (a * b) * c
# a * b is equivalent to a.__mul__(b), which returns
# an intermediate Value instance t1
# t1 * c is equivalent to t1.__mul__(c), which returns
# the resulting Value, which is stored in y
When two Value objects are multiplied, the dunder method __mul__ returns a new object, i.e., a new intermediate term $t_i$. Suppose $t_i = t_j t_k$ and we know $\frac{\partial f}{\partial t_i}$. Then

$$\frac{\partial t_i}{\partial t_j} = t_k \qquad \text{and} \qquad \frac{\partial t_i}{\partial t_k} = t_j,$$

so by $(4)$, $t_i$ contributes $\frac{\partial f}{\partial t_i}\, t_k$ to $\frac{\partial f}{\partial t_j}$ and $\frac{\partial f}{\partial t_i}\, t_j$ to $\frac{\partial f}{\partial t_k}$. With this, we can define the backprop logic in the __mul__ method of Value.
class Value:
    def __init__(self, val, reqgrad=False, grad_fn=None):
        self.val, self.reqgrad, self.grad_fn = val, reqgrad, grad_fn
        self.grad = 0

    def __mul__(self, other):
        # define a closure:
        # multiplying self and other returns a new Value object (call it x)
        # grad is the derivative of the final result w.r.t. x
        # during the forward computation we don't know grad yet
        # this closure becomes the grad_fn of x
        # when x.back is called (see the back method below):
        #   x.grad_fn, i.e., the closure below, is called
        #   the back methods of x's arguments (self and other) are called
        def grad_fn(grad):
            if self.reqgrad: self.back(grad * other.val)
            if other.reqgrad: other.back(grad * self.val)
        return Value(self.val * other.val, self.reqgrad | other.reqgrad, grad_fn)

    def back(self, grad=1):
        # say x = a * b and dy/dx = x.grad (by definition)
        # x.grad_fn is a closure holding references to x's arguments a and b
        # x.grad is passed into x.grad_fn as grad
        # inside grad_fn: dy/da = x.grad * b.val and dy/db = x.grad * a.val
        # so a.back(grad * b.val) is called, and a.grad becomes a.grad + grad * b.val
        # vice versa for b
        self.grad += grad  # accumulate, per eqs (4) and (5)
        if self.grad_fn: self.grad_fn(self.grad)
For a unary operation like log
import math  # for math.log

class Value:
    # ...
    def log(self):
        def grad_fn(grad):
            # x = a.log()
            # then dy/da = dy/dx * dx/da = grad * (1 / a.val)
            if self.reqgrad: self.back(grad * 1 / self.val)
        return Value(math.log(self.val), self.reqgrad, grad_fn)
    # ...
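Other binary operations follow the same pattern as __mul__. As a sketch (not the only way to write it), here's __add__; the derivative of a sum w.r.t. either operand is 1, so the incoming grad passes through unchanged.

class Value:
    # ...
    def __add__(self, other):
        def grad_fn(grad):
            # x = a + b, so dy/da = grad * 1 and dy/db = grad * 1
            if self.reqgrad: self.back(grad)
            if other.reqgrad: other.back(grad)
        return Value(self.val + other.val, self.reqgrad | other.reqgrad, grad_fn)
    # ...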
I'll leave the rest of the operations (-, /, **, etc.) as an exercise. Refer to my implementation on GitHub if you want. The way it'll work for this function is
def f(x1: Value, x2: Value) -> Value:
    return (x1 * x2 * (x1 + x2)).log() / (x1 - x2) ** 2

x1, x2 = Value(2, reqgrad=True), Value(10, reqgrad=True)
y = f(x1, x2)
y.back()  # backprop to differentiate
print(f"df / dx1 = {x1.grad}, df / dx2 = {x2.grad}")
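As a sanity check (assuming the remaining operators are implemented), compare against the finite-difference approximation from the start of the post; f_plain below is just the same formula on plain floats, and the numbers should come out close to x1.grad and x2.grad.

import math

def f_plain(x1, x2):
    # same formula as f, but on plain floats
    return math.log(x1 * x2 * (x1 + x2)) / (x1 - x2) ** 2

h = 1e-6
print((f_plain(2 + h, 10) - f_plain(2, 10)) / h)  # should be close to x1.grad
print((f_plain(2, 10 + h) - f_plain(2, 10)) / h)  # should be close to x2.grad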
Further reading:
- Karpathy's micrograd
- UofT lecture notes
- CMU lecture notes (more technical)
- YT: Understanding Automatic Differentiation via Computation Graphs
- What is Automatic Differentiation?