Deep learning notes 07: PonderNet - Learn when to stop thinking

#papers

This post is from a series of quick notes written primarily for personal use while reading random ML/SWE/CS papers. As such, they might be incomprehensible and/or flat out wrong.

PonderNet - Learn when to stop thinking

  • Recurrent(ly run) network that can decide when to stop computation
  • In standard NNs the amount of computation grows with the size of the input/output and/or is static; it does not scale with problem complexity
  • End-to-end trained to compromise between: computation cost (# of steps), training prediction accuracy, and generalization
    • Halting node: predicts the probability of halting, conditional on not having halted before
    • The rest of the network can be any architecture (RNN, CNN, Transformer, …, capsule network, …)
  • Input x is first processed into a hidden state; the step function s(...) is then applied repeatedly: (h_{i+1}, y_i, λ_i) = s(h_i)
    • Each step returns the next hidden state (h_{i+1}), an output (y_i), and a probability of stopping (λ_i)
    • Probability of stopping at step n: p_n = λ_n * Π_{i=1..n-1} (1 - λ_i)
  • At inference λ is used probabilistically, i.e. at each step the halting decision is sampled as Bernoulli(λ_n); see the sketch below
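A minimal sketch of the step function and the sampled inference loop as I read it. The GRUCell-based step, the layer sizes, the `max_steps` cap, and all names here are my own assumptions, not the paper's code:

```python
import torch
import torch.nn as nn

class PonderStep(nn.Module):
    """One application of s: (h_{i+1}, y_i, λ_i) = s(h_i)."""
    def __init__(self, x_dim, h_dim, y_dim):
        super().__init__()
        self.cell = nn.GRUCell(x_dim, h_dim)     # the "any architecture" slot
        self.out_head = nn.Linear(h_dim, y_dim)  # produces y_i
        self.halt_head = nn.Linear(h_dim, 1)     # produces λ_i

    def forward(self, x, h):
        h_next = self.cell(x, h)
        y = self.out_head(h_next)
        lam = torch.sigmoid(self.halt_head(h_next)).squeeze(-1)  # λ_i in (0, 1)
        return h_next, y, lam

@torch.no_grad()
def infer(step, x, h, max_steps=20):
    """Run until a sampled Bernoulli(λ_i) says halt (single example, batch size 1)."""
    for _ in range(max_steps):
        h, y, lam = step(x, h)
        if torch.bernoulli(lam).item() == 1.0:  # sample the halting decision
            break
    return y
```

E.g. `infer(PonderStep(8, 16, 1), torch.randn(1, 8), torch.zeros(1, 16))`.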
  • Training loop:
    • Feed x into the encoder to get h_0, then unroll the network for n steps regardless of λ
    • Consider all outputs at the same time; loss = p_1*L(y_1) + p_2*L(y_2) + … + p_n*L(y_n)
    • -> Possibly unstable -> the loss can be lowered in two ways: make the outputs y_i better, or make the weights p_i of high-loss steps smaller
    • Regularization term KL(p || Geometric(λ_p)) -> forces the halting distribution toward a geometric prior with hyperparameter λ_p; see the sketch below
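A toy version of that training loss, under the same assumptions as the sketch above. The MSE task loss, the defaults for `n_steps`/`lambda_p`/`beta`, and the prior truncation + renormalization are mine (the paper handles the leftover probability mass differently):

```python
import torch
import torch.nn.functional as F

def ponder_training_loss(step, x, h0, target, n_steps=10, lambda_p=0.2, beta=0.01):
    """Unroll for n_steps regardless of λ; loss = Σ_n p_n·L(y_n) + β·KL(p || Geometric(λ_p))."""
    h = h0
    not_halted = torch.ones(x.shape[0])          # running Π_{i<n}(1 - λ_i), per example
    ps, step_losses = [], []
    for _ in range(n_steps):
        h, y, lam = step(x, h)
        ps.append(not_halted * lam)              # p_n = λ_n · Π_{i<n}(1 - λ_i)
        not_halted = not_halted * (1 - lam)
        step_losses.append(F.mse_loss(y, target, reduction="none").mean(dim=-1))
    p = torch.stack(ps)                          # (n_steps, batch)
    rec = (p * torch.stack(step_losses)).sum(dim=0).mean()  # Σ_n p_n·L(y_n)

    # Geometric prior truncated to n_steps; renormalizing both sides is my shortcut.
    k = torch.arange(n_steps, dtype=torch.float32)
    prior = lambda_p * (1 - lambda_p) ** k
    prior = prior / prior.sum()
    p_norm = (p / p.sum(dim=0, keepdim=True).clamp_min(1e-8)).clamp_min(1e-8)
    kl = (p_norm * (p_norm / prior.unsqueeze(1)).log()).sum(dim=0).mean()
    return rec + beta * kl
```

The geometric prior biases the network toward halting early (expected steps ≈ 1/λ_p) while still leaving mass on later steps.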
  • Contrast vs ACT:
    • ACT treats the output as a weighted average of the per-step outputs: loss = L(p_1*y_1 + ... + p_n*y_n)
    • Early outputs therefore need to be compatible (averageable) with later outputs
    • Less dynamic; needs more steps in the experiments and is worse at extrapolation
    • PonderNet correctly uses more steps for more complex problems (toy loss comparison below)
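A toy scalar example (all numbers invented) of why the two loss shapes behave differently: ACT-style scores the weighted-average output, so wrong early outputs can cancel out, while PonderNet scores each step separately:

```python
import torch

p = torch.tensor([0.2, 0.3, 0.5])   # made-up halting distribution over 3 steps
ys = torch.tensor([0.5, 1.5, 1.0])  # made-up per-step scalar outputs, target = 1.0
L = lambda y: (y - 1.0) ** 2

act_style = L((p * ys).sum())     # L(Σ p_i·y_i) = 0.0025: near-perfect via cancellation
ponder_style = (p * L(ys)).sum()  # Σ p_i·L(y_i) = 0.125: wrong steps still penalized
```

This is the Jensen gap, L(E[y]) ≤ E[L(y)] for convex L, which is exactly why ACT's early outputs must stay compatible with later ones.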