The Simple Essence of Automatic Differentiation
This post if from a series of notes written for personal usage while reading random ML/SWE/CS papers/lectures. The notes weren’t originally intended for the eyes of other people and therefore might be incomprehensible and/or flat out wrong. These notes might be especially ish as my knowledge of both category theory and lambda calculus is relatively limited.
- Don’t look at backprop in terms of graphs, treat it as algebraic composition over functions
- Graphs can be good for visualization, but AD has nothing to do with graphs
- Problem with graphs -> hard to deal with, trees easier but can be exp. big
- Derivative of a function
f::(a->b)
at a certain point is a linear mapf'::(a-:b)
- Can’t compose derivative operator:
D::(a->b)->(a->(a-:b))
- When composed
D (g o f) a = D g (f a) o D f a
it also needsf a
i.e.b
- When composed
- Solution:
D'::(a->b) -> (a->(b * (a-:b)))
, doesn’t produce just derivative butf(a) * f'::b * (a-:b)
- Now for
D' (g o f)
we only need(D' g) o (D' f)
-> easily composable
- Now for
- Can’t compose derivative operator:
- Compile the composition to various categories: graph repres., functions composition, derivative, …
- Category: identity ::
a->a
, composition (o
) ::(b->c)->(a->b) -> (a->c)
- Cartesian category: notion of Cartesian multipl.: pairs (
*
), select left, select right
- Category: identity ::
- Automatic differentiation
- Require D’ to preserve Cartesian category structure -> solve -> instantiation that gives us automatic diff.
- Derivatives of complex stuff happens “automatically” through composition of D’
- Due to composition -> a lot of computation sharing -> efficient
- Problem with sum types: not continuous, ?Church encoding
- Generally three ways of AD:
- Numerical approximation: terrible performance and accuracy
- Symbolic: ~to calculus at high-school
- AD: Chain-rule based, …, ~symbolic done by compiler
- Linear maps could be replaced with other conforming functions -> generalized AD
- Useful when we need to extract data representation (gradient, …)
- Replace linear maps with matrices:
- Enables efficient extraction of e.g. gradient without having to run base vectors matrix (domainDim^2) through lin. map
- Every composition is matrix multiplication -> associative -> impacts performance
- Not easy to figure out the best association but depends only on dims -> types
- If domain » co-domain (e.g.: ML) reverse mode AD (all left) is usually best
- Left associating composition
- Corresponds to backprop
- CPS-like category: represent
a->b
by(b->r) -> (a->r)
~ continuation passing style - Given
f
, we’ll interpretf
as something that composes withf
:o f
- Results in left-composition (need to initialize with identity)
- Reverse mode AD: generalized AD category parametrized by continuation transform version of matrixes
Written
by Petr Houška
on