Deep learning notes 07: PonderNet - Learn when to stop thinking

This post is from a series of quick notes written primarily for personal use while reading random ML/SWE/CS papers. As such, they might be incomprehensible and/or flat-out wrong.

PonderNet - Learn when to stop thinking

• Recurrent(ly run) network that can decide when to stop computation
• In a standard NN the amount of computation is either static or grows with the size of the input/output, not with problem complexity
• End-to-end trained to compromise between: computation cost (# of steps), training prediction accuracy, and generalization
• Halting node: predicts the probability of halting at the current step, conditional on not having halted before
• The rest of the network can be any architecture (RNN, CNN, transformer, …, capsule network, …)
• Input x is encoded into a hidden state h_0; each step then applies the step function s: (h_{i+1}, y_i, λ_i) = s(h_i)
• Each step returns the next hidden state (h_{i+1}), an output (y_i), and the probability of stopping (λ_i)
• Probability to stop at step n: p_n = λ_n * ∏_{i=1..n-1} (1 - λ_i)
• At inference λ is used probabilistically: at each step the decision to halt is sampled with probability λ_i
• Training loop:
• Input x into encoder, get h_0, … unroll the network for n steps regardless of λ
• Consider all outputs at the same time; loss = p_1*L(y_1)+p_2*L(y_2)+...+p_n*L(y_n)
• -> Possibly unstable: the loss can be lowered in two ways, by making y_i better or by making p_i smaller
• Regularization: KL(p_i || Geometric(λ_p)) -> pushes the halting distribution toward a geometric prior, so the λs stay close to the hyperparameter λ_p
• Contrast vs ACT:
• ACT considers the output a weighted average of the per-step outputs: loss = L(p_1 * y_1 + ... + p_n * y_n)
• In ACT, early results therefore need to be compatible with later results
• ACT is less dynamic; it needs more steps in experiments and is worse at extrapolation
• PonderNet correctly uses more steps for more complex problems
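
The halting distribution, the training loss with its KL regularizer, the ACT-style contrast, and the sampled stop at inference can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the paper's implementation: all function names, the defaults lambda_p=0.2 and beta=0.01, the choice to put leftover probability mass on the last unrolled step, and the squared-error loss in the ACT-style contrast are made up here for illustration.

```python
import math
import random


def halting_distribution(lambdas):
    """p_n = λ_n * Π_{i<n} (1 - λ_i) for every unrolled step n."""
    probs, not_halted = [], 1.0
    for lam in lambdas:
        probs.append(lam * not_halted)
        not_halted *= 1.0 - lam
    return probs


def close_distribution(probs):
    """Assumption: leftover mass goes to the last step so the p_n sum to 1."""
    out = list(probs)
    out[-1] += 1.0 - sum(out)
    return out


def ponder_loss(lambdas, step_losses, lambda_p=0.2, beta=0.01):
    """sum_n p_n * L(y_n) + beta * KL(p || geometric prior with rate lambda_p).

    Assumes 0 < lambda_p < 1 so that the prior is nonzero at every step.
    """
    p = close_distribution(halting_distribution(lambdas))
    # Geometric prior truncated to the same number of steps (assumption).
    prior = close_distribution(halting_distribution([lambda_p] * len(lambdas)))
    expected = sum(pi * li for pi, li in zip(p, step_losses))
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, prior) if pi > 0)
    return expected + beta * kl


def act_style_loss(probs, outputs, target):
    """Contrast: the loss of the weighted-average output (squared error assumed)."""
    mixed = sum(pi * yi for pi, yi in zip(probs, outputs))
    return (mixed - target) ** 2


def sample_halting_step(lambdas, rng=random):
    """Inference: at each step halt with probability λ_n; returns a 1-based step."""
    for n, lam in enumerate(lambdas, start=1):
        if rng.random() < lam:
            return n
    return len(lambdas)  # forced halt at the last unrolled step
```

With probs [0.5, 0.5], scalar outputs [0.0, 2.0] and target 1.0, act_style_loss returns 0.0 because the average happens to hit the target, while the per-step expectation 0.5 * 1.0 + 0.5 * 1.0 = 1.0 penalizes both wrong outputs. That is the sense in which each step's output has to stand on its own in the PonderNet loss.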