Fast weights: attend to the recent past
This post is from a series of notes written for personal use while reading random ML/SWE/CS papers. The notes weren’t originally intended for the eyes of other people and therefore might be incomprehensible and/or flat out wrong.
Paper in question: 1610.06258. Relevant lectures here and here.
- Traditionally D/RNN have two types of memory:
    - Standard weights `W`: long term memory of the whole dataset
    - Hidden state `h(t)`: maintained by active neurons / directly storing activations, limited, very immediate
        - Remembered things heavily influence currently processed stuff, limited capacity
- New intermediate “fast weights” `A(t)`: weights/synapses, but with faster learning and rapid decay of temporary information
    - Higher capacity associative “network” that stores past hidden states (e.g. Hopfield network)
- New (hidden) state is a combination of traditional new RNN state and its repeated modulation with fast weights
    - Two steps: preliminary vector `h_0(t+1) = f(W h(t) + C x(t))` and `s: 1..S` steps of an inner loop with `h_S(t+1) == h(t+1)`
        - `h_{s+1}(t+1) = f([W h(t) + C x(t)] + A(t) h_s(t+1))`
    - Multiple steps `s: 1..S` allow the new state to settle; the bracketed term `[W h(t) + C x(t)]` is the same for all steps (it is the preliminary vector without `f(..)`)
    - Fast weights `A(t+1)` updated with `h(t+1)` at the end of the timestep, e.g. using the Hopfield rule + heavy decay (a minimal sketch follows this list)
    - Backprop can go (even) through fast weights (doesn’t update them) -> the network learns to work with them
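A minimal numpy sketch of one timestep, following the equations above. The names `W`, `C`, `A`, `h`, `x` and `f` mirror the notes; the hidden/input sizes, the decay `lam`, the write strength `eta`, the number of inner steps `S`, and the choice of `f = tanh` are illustrative assumptions, not the paper’s exact configuration.

```python
import numpy as np

def fast_weights_step(x_t, h_t, A_t, W, C, lam=0.95, eta=0.5, S=3, f=np.tanh):
    """Advance hidden state h and fast-weight matrix A by one timestep."""
    sustained = W @ h_t + C @ x_t       # [W h(t) + C x(t)], identical for every inner step
    h_s = f(sustained)                  # preliminary vector h_0(t+1)
    for _ in range(S):                  # inner loop lets the new state settle
        h_s = f(sustained + A_t @ h_s)  # h_{s+1}(t+1) = f([W h(t) + C x(t)] + A(t) h_s(t+1))
    # Hopfield-style outer-product update with heavy decay at the end of the timestep
    A_next = lam * A_t + eta * np.outer(h_s, h_s)
    return h_s, A_next

# Tiny usage example: hidden size 4, input size 3, random parameters.
rng = np.random.default_rng(0)
W = 0.1 * rng.standard_normal((4, 4))
C = 0.1 * rng.standard_normal((4, 3))
h, A = np.zeros(4), np.zeros((4, 4))
for x in rng.standard_normal((5, 3)):   # feed a short random sequence
    h, A = fast_weights_step(x, h, A, W, C)
```

Note that the sustained term is computed once and reused in every inner step, which is exactly the “same for all steps” point above.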
- Problem with minibatches: different fast weights for each sequence in a batch -> explicit attention over past hidden states
    - Unclear details on how that solves the minibatch problem, but it is less computationally intensive: `k*n` vs `n^2` (see the sketch after this list)
        - `A(t) h_s(t+1) ≈ sum_{i=1..t} λ^(t-i) h(i) <h(i), h_s(t+1)>`
    - Non-parametrized attention with strength ~ scalar product between the past state and current hidden state
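As a sanity check on the identity above, a small numpy sketch: if `A(t)` is built purely from decayed outer products of past hidden states (starting from `A = 0` and with the write strength set to 1, both assumptions of this sketch), then multiplying `A(t)` by the current inner-loop state gives exactly the attention-weighted sum of those past states.

```python
import numpy as np

rng = np.random.default_rng(1)
lam, dim, T = 0.9, 4, 6
h_hist = rng.standard_normal((T, dim))   # past hidden states h(1)..h(T)
q = rng.standard_normal(dim)             # current inner-loop state h_s(t+1)

# Fast-weight matrix accumulated with decay: A(t) = sum_i λ^(t-i) h(i) h(i)^T
A = np.zeros((dim, dim))
for h_i in h_hist:
    A = lam * A + np.outer(h_i, h_i)

# Attention form: sum_i λ^(t-i) h(i) <h(i), h_s(t+1)>
attn = sum(lam ** (T - 1 - i) * h_hist[i] * (h_hist[i] @ q) for i in range(T))

print(np.allclose(A @ q, attn))          # True -> the two forms coincide
```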
- Notes:
    - Interpretation: attracts each new state towards recent hidden states
    - Benefit: frees the hidden state to learn a good representation for the final classification
    - Needs layer normalization of the fast-weights (inner-loop) output (a small sketch follows)
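A minimal sketch of the layer-normalized inner-loop step hinted at in the last note; the learned gain/bias parameters of layer normalization are omitted for brevity (an assumption of this sketch).

```python
import numpy as np

def layer_norm(z, eps=1e-5):
    # Plain layer normalization without learned gain/bias (omitted for brevity).
    return (z - z.mean()) / np.sqrt(z.var() + eps)

def inner_step_ln(sustained, A_t, h_s, f=np.tanh):
    # h_{s+1}(t+1) = f(LN([W h(t) + C x(t)] + A(t) h_s(t+1)))
    return f(layer_norm(sustained + A_t @ h_s))
```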
Written by Petr Houška