Deep learning notes 07: PonderNet - Learn when to stop thinking
This post is from a series of quick notes written primarily for personal use while reading random ML/SWE/CS papers. As such, they might be incomprehensible and/or flat-out wrong.
- Recurrent(ly run) network that can decide when to stop computation
- In a standard NN, the amount of computation is either static or grows with the size of the input/output, not with problem complexity
- Trained end-to-end to balance computation cost (# of steps), training prediction accuracy, and generalization
- Halting node: predicts the probability of halting at the current step, conditioned on not having halted at any earlier step
- The rest of the network can be any architecture (RNN, CNN, transformer, …, capsule network, …)
- Input $x$; $x$ is processed into a hidden state $h_i$, which is processed by a step function $s(\cdot)$: $(h_{i+1}, y_i, \lambda_i) = s(h_i)$
  - Each step returns the next hidden state ($h_{i+1}$), an output ($y_i$), and the probability of stopping ($\lambda_i$)
    - Probability to stop at step $n$: $p_n = \lambda_n \prod_{i=1}^{n-1} (1 - \lambda_i)$
  - At inference, $\lambda$ is used probabilistically (i.e. halting is sampled: at step $i$, stop with probability $\lambda_i$)
  - Training loop:
    - Input $x$ into the encoder, get $h_0$, … unroll the network for $n$ steps regardless of $\lambda$
    - Consider all outputs at the same time: $\text{loss} = p_1 L(y_1) + p_2 L(y_2) + \dots + p_n L(y_n)$
    - -> Possibly unstable -> the loss can be lowered in two ways: make $y_i$ better, or make $p_i$ smaller
    - Regularization term $\mathrm{KL}(p_i \,\|\, \mathrm{Geometric}(\lambda_p))$ -> forces the lambdas to be similar to the hyperparameter $\lambda_p$
- Contrast with ACT (Adaptive Computation Time):
  - ACT considers the output a weighted average of outputs: $\text{loss} = L(p_1 y_1 + \dots + p_n y_n)$
    - Early results need to be compatible with later results
  - ACT is less dynamic: it needs more steps in experiments and is worse at extrapolation
  - PonderNet correctly takes more steps for more complex problems
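A minimal plain-Python sketch of the halting distribution, the training loss, and inference-time sampling described above. It is not the paper's code. The per-step λ values and per-step losses L(y_i) are passed in as plain lists rather than produced by a network, and all function names are hypothetical. Two details go beyond the notes and are assumptions. First, the last unrolled step is forced to halt (λ_n = 1) so that the p_n sum to 1, and the geometric prior is truncated the same way. Second, beta is an assumed weight on the KL term.

```python
import math
import random


def halting_distribution(lambdas):
    """p_n = λ_n * prod_{i<n} (1 - λ_i): probability of halting exactly at step n."""
    probs = []
    not_halted = 1.0  # probability of not having halted at any earlier step
    for lam in lambdas:
        probs.append(lam * not_halted)
        not_halted *= 1.0 - lam
    return probs


def force_last_halt(lambdas):
    """Assumption: the last unrolled step always halts, so the p_n sum to 1."""
    return list(lambdas[:-1]) + [1.0]


def geometric_prior(lambda_p, n_steps):
    """Geometric prior truncated to n_steps: every λ equals λ_p, tail mass folded into the last step."""
    return halting_distribution(force_last_halt([lambda_p] * n_steps))


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def pondernet_loss(step_losses, lambdas, lambda_p, beta):
    """sum_i p_i * L(y_i) plus beta * KL(p || geometric prior); beta is an assumed weight."""
    p = halting_distribution(force_last_halt(lambdas))
    reconstruction = sum(pi * li for pi, li in zip(p, step_losses))
    regularization = kl_divergence(p, geometric_prior(lambda_p, len(lambdas)))
    return reconstruction + beta * regularization


def sample_halting_step(lambdas, rng):
    """Inference: at step n, halt with probability λ_n (a Bernoulli draw)."""
    for n, lam in enumerate(force_last_halt(lambdas), start=1):
        if rng.random() < lam:  # rng.random() is in [0, 1), so λ = 1.0 always halts
            return n
    return len(lambdas)  # defensive only: the forced last λ = 1.0 always halts


# Example: three unrolled steps, each with λ = 0.5 (the last one is forced to halt).
lambdas = [0.5, 0.5, 0.5]
print(halting_distribution(force_last_halt(lambdas)))  # [0.5, 0.25, 0.25]
print(pondernet_loss([1.0, 2.0, 4.0], lambdas, lambda_p=0.5, beta=0.01))  # 2.0
```

When every λ equals λ_p, the halting distribution matches the prior exactly. The KL term then vanishes and only the weighted per-step loss remains. Moving the λ values away from λ_p makes the regularizer positive.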
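A toy comparison of the two loss shapes, using scalar outputs, squared error, and hand-picked halting probabilities p_i. The function names are hypothetical. As a simplification, both losses use the same p_i; real ACT derives its weights differently, from a cumulative halting sum plus a remainder term. The point is that ACT only judges the blended output, so individually wrong steps can cancel out. PonderNet scores every step's output on its own, since at inference any step may be the one that gets sampled.

```python
def squared_error(y, target):
    return (y - target) ** 2


def act_style_loss(outputs, p, target):
    """ACT-style: one loss on the p-weighted average output, L(p_1*y_1 + ... + p_n*y_n)."""
    blended = sum(pi * yi for pi, yi in zip(p, outputs))
    return squared_error(blended, target)


def pondernet_style_loss(outputs, p, target):
    """PonderNet-style: p-weighted sum of per-step losses, p_1*L(y_1) + ... + p_n*L(y_n)."""
    return sum(pi * squared_error(yi, target) for pi, yi in zip(p, outputs))


# Two steps that are each wrong by 1.0 but average exactly to the target.
outputs = [0.0, 2.0]
p = [0.5, 0.5]
target = 1.0
print(act_style_loss(outputs, p, target))        # 0.0
print(pondernet_style_loss(outputs, p, target))  # 1.0
```

In this toy case, the ACT-style loss is zero even though neither step is correct on its own.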
Written by Petr Houška