Deep learning notes 07: PonderNet - Learn when to stop thinking
This post is from a series of quick notes written primarily for personal use while reading random ML/SWE/CS papers. As such, they might be incomprehensible and/or flat-out wrong.
PonderNet - Learn when to stop thinking
- Recurrent(ly run) network that can decide when to stop computation
- In standard NNs, the amount of computation grows with the size of the input/output and/or is static, not with problem complexity
- Trained end-to-end to balance computation cost (# of steps), training prediction accuracy, and generalization
- Halting node: predicts the probability of halting at the current step, conditional on not having halted before
- The rest of the network can be any architecture (RNN, CNN, Transformer, …, capsule network, …)
- Input $x$ is processed into a hidden state $h_i$, which is then processed via a step function $s(\cdot)$: $(h_{i+1}, y_i, \lambda_i) = s(h_i)$
  - Each step returns the next hidden state ($h_{i+1}$), an output ($y_i$), and the probability of stopping ($\lambda_i$)
  - Probability of stopping at step $n$: $p_n = \lambda_n \prod_{i=1}^{n-1} (1 - \lambda_i)$
- At inference, $\lambda$ is used probabilistically (i.e., halting is sampled: at step $i$ the network stops with probability $\lambda_i$)
- Training loop:
  - Feed $x$ into the encoder to get $h_0$, …, then unroll the network for $n$ steps regardless of $\lambda$
  - Consider all outputs at the same time: $\text{loss} = p_1 L(y_1) + p_2 L(y_2) + \dots + p_n L(y_n)$
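    - -> Possibly unstable -> two ways to lower the loss: make $y_i$ better or make $p_i$ smaller
  - Regularization: $\mathrm{KL}(p \,\|\, \mathrm{Geometric}(\lambda_p))$, where $p = (p_1, \dots, p_n)$ is the halting distribution -> pushes the $\lambda_i$ towards the hyperparameter $\lambda_p$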
- Contrast with ACT (Adaptive Computation Time):
  - ACT treats the output as a weighted average of the step outputs: $\text{loss} = L(p_1 y_1 + \dots + p_n y_n)$
    - Early results need to be compatible with later results
  - Less dynamic; needs more steps in the experiments and is worse at extrapolation
  - PonderNet correctly uses more steps for more complex problems
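A minimal sketch of the halting mechanics described in the notes, in plain Python. It assumes the per-step halting probabilities $\lambda_i$ are available as floats; `halting_distribution`, `ponder_inference`, `toy_step` and the `max_steps` cap are hypothetical names chosen for illustration, not taken from the paper or from any library. The first function turns the conditional $\lambda_i$ into the unconditional $p_n$; the second samples the halting decision step by step, as done at inference.

```python
import random
from typing import Any, Callable, List, Tuple

def halting_distribution(lambdas: List[float]) -> List[float]:
    """Unconditional halting probabilities p_n = lambda_n * prod_{i<n} (1 - lambda_i)."""
    probs = []
    not_halted = 1.0  # probability of not having halted at any earlier step
    for lam in lambdas:
        probs.append(lam * not_halted)
        not_halted *= 1.0 - lam
    return probs

# Hypothetical step-function type: h_i -> (h_{i+1}, y_i, lambda_i)
StepFn = Callable[[Any], Tuple[Any, Any, float]]

def ponder_inference(step_fn: StepFn, h0: Any, max_steps: int,
                     rng: random.Random) -> Tuple[Any, int]:
    """Run steps until a Bernoulli(lambda_n) draw says halt; return (output, step)."""
    h = h0
    y = None
    for n in range(1, max_steps + 1):
        h, y, lam = step_fn(h)
        if rng.random() < lam:  # halt at step n with probability lambda_n
            return y, n
    return y, max_steps  # assumption: fall back to the last output at the step cap

def toy_step(h: int) -> Tuple[int, int, float]:
    # Hypothetical toy step: never halts on the first two steps, always on the third
    return h + 1, h + 1, 0.0 if h < 2 else 1.0

print(halting_distribution([0.5, 0.5, 1.0]))  # [0.5, 0.25, 0.25]
print(ponder_inference(toy_step, 0, max_steps=10, rng=random.Random(0)))  # (3, 3)
```

With a finite number of steps, the $p_n$ sum to 1 only when the final $\lambda$ is 1; otherwise the leftover mass $\prod_{i=1}^{n}(1 - \lambda_i)$ is the probability of not having halted yet.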
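The training objective from the notes (the probability-weighted per-step loss plus a KL term towards a geometric prior) can be sketched as below. Several details are assumptions rather than something the notes state: the per-step losses $L(y_i)$ are precomputed floats, the geometric prior is truncated to the unrolled steps and renormalized, and the KL term is scaled by a hypothetical weight `beta`. All function names are illustrative.

```python
import math
from typing import List

def halting_distribution(lambdas: List[float]) -> List[float]:
    """p_n = lambda_n * prod_{i<n} (1 - lambda_i)."""
    probs, not_halted = [], 1.0
    for lam in lambdas:
        probs.append(lam * not_halted)
        not_halted *= 1.0 - lam
    return probs

def geometric_prior(lambda_p: float, n_steps: int) -> List[float]:
    """Geometric(lambda_p) over steps 1..n_steps, renormalized after truncation.
    Assumes 0 < lambda_p < 1 so that every entry is positive."""
    q = [lambda_p * (1.0 - lambda_p) ** (n - 1) for n in range(1, n_steps + 1)]
    total = sum(q)
    return [x / total for x in q]

def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q); terms with p_n == 0 contribute nothing."""
    return sum(pn * math.log(pn / qn) for pn, qn in zip(p, q) if pn > 0.0)

def ponder_loss(step_losses: List[float], lambdas: List[float],
                lambda_p: float, beta: float) -> float:
    p = halting_distribution(lambdas)
    # Reconstruction term: p_1*L(y_1) + ... + p_n*L(y_n)
    reconstruction = sum(pn * ln for pn, ln in zip(p, step_losses))
    # Regularization term: pulls the halting distribution towards the geometric prior
    regularization = kl_divergence(p, geometric_prior(lambda_p, len(lambdas)))
    return reconstruction + beta * regularization

print(ponder_loss([2.0, 4.0], [0.5, 1.0], lambda_p=0.5, beta=0.0))  # 3.0
```

Setting the last $\lambda$ to 1 in the example makes $p$ sum to 1, so the KL term compares two proper distributions; with `beta=0.0` only the reconstruction term $0.5 \cdot 2 + 0.5 \cdot 4 = 3$ remains.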
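To make the ACT contrast concrete, here is a toy comparison of the two loss shapes as written in the notes, using squared error on scalar outputs. The choice of squared error is an assumption, ACT's separate ponder-cost term is left out, and the function names are hypothetical.

```python
from typing import Callable, List

def squared_error(prediction: float, target: float) -> float:
    return (prediction - target) ** 2

def act_style_loss(outputs: List[float], probs: List[float], target: float,
                   loss_fn: Callable[[float, float], float] = squared_error) -> float:
    # ACT-style (as summarized in the notes): loss of the probability-weighted average output
    mixed = sum(p * y for p, y in zip(probs, outputs))
    return loss_fn(mixed, target)

def ponder_style_loss(outputs: List[float], probs: List[float], target: float,
                      loss_fn: Callable[[float, float], float] = squared_error) -> float:
    # PonderNet-style: probability-weighted loss, each step's output scored on its own
    return sum(p * loss_fn(y, target) for p, y in zip(probs, outputs))

outputs = [0.0, 2.0]
probs = [0.5, 0.5]
print(act_style_loss(outputs, probs, 1.0))     # 0.0
print(ponder_style_loss(outputs, probs, 1.0))  # 1.0
```

With outputs 0.0 and 2.0 around a target of 1.0, the averaged prediction is exactly right, so the ACT-style loss is zero even though neither step's output is correct on its own, while the PonderNet-style loss still penalizes both steps. This is one way to read the note that, in ACT, early results only make sense in combination with later ones.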
Written by Petr Houška