|Context||NIPS 2015 Oral|
Neural networks are now often trained with very large output spaces, such as language models where the next word must be predicted from a vocabulary of hundreds of thousands of words. Even if the representation prior to the output is of reasonable size, the computation is expensive because the output layer involves a weight matrix whose size is the product of the output dimensionality and the prior representation dimensionality. Currently employed approaches include changing the probability model (e.g. hierarchical softmax) or sampling a small subset of the output dimensions for each example. Relative to computing the whole output, both approximations are crude. This work shows that, for a restricted set of loss functions, the full exact gradient update can be computed in $O(d^2)$ time, where $d$ is the representation dimension, not the output dimension.
To represent a $K$-sparse vector in $D$-dimensional space, you only need to store the indices of the non-zero elements and their values. Multiplying such a vector by a dense $d \times D$ matrix costs $O(Kd)$ instead of $O(Dd)$. Consider a network whose input and output are both $D$-dimensional (very high) but $K$-sparse, whose intermediate layers are only $d$-dimensional, and which is trained with the squared error loss, where $K \ll d \ll D$. For forward propagation, the input-to-hidden multiplication costs $O(Dd)$, or $O(Kd)$ when sparse multiplication is used, each hidden-to-hidden multiplication costs $O(d^2)$, and the output layer costs $O(Dd)$. Because the last hidden representation is dense in general (there is no guarantee that it is sparse), this final multiplication remains prohibitively expensive. Similarly, in backpropagation, the gradient calculation and update of the output weight matrix cost $O(Dd)$, the hidden gradients and updates are all $O(d^2)$, and the input-layer gradient and weight update can be $O(Kd)$. However, it is possible to compute the loss $L$, the gradient of the loss with respect to the last hidden layer, and the exact same gradient update for the output weight matrix in $O(d^2)$ time, using the following two tricks.
First, we keep an up-to-date $d \times d$ matrix $Q = W^\top W$, where $W$ is the $D \times d$ output weight matrix. For squared error loss, this makes it possible to compute the loss without ever computing the output, in $O(d^2)$ time. The same trick gives the gradient of the loss with respect to the last hidden layer in $O(d^2)$. Second, the update of $W$ is a non-sparse rank-one update that touches all of $W$, but it can be represented implicitly by factorizing $W$ into a $D \times d$ matrix times a $d \times d$ matrix: the $d \times d$ factor is updated completely, while only $K$ rows of the $D \times d$ factor are updated, provided that the inverse of the $d \times d$ factor is kept up to date (itself an $O(d^2)$ operation). This is not the same as doing ordinary backprop on the factorized matrices, but it produces exactly the same update to $W$. Together these tricks give a full nine-step algorithm costing roughly $12d^2$ operations, which brings a substantial computational benefit as well as a significant reduction in memory access. It can be extended to the minibatch case and applies to the spherical loss family, but not to the regular log softmax. Unlike the naive calculation, its cost does not grow with the output dimension $D$, resulting in a speedup of two to three orders of magnitude in practice while computing exactly the same thing.
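The two tricks can be checked numerically with a minimal NumPy sketch. This is a reconstruction from the description above, not the paper's nine-step algorithm verbatim. The factor names `V` ($D \times d$) and `U` ($d \times d$), the function name `sparse_step`, the sizes, and the learning rate `eta` are all illustrative assumptions. So are the unhalved squared error $\|Wh - y\|^2$ and the use of the Sherman-Morrison identity to keep the inverse of `U` current.

```python
import numpy as np

rng = np.random.default_rng(0)
D, d, K = 2000, 16, 3      # illustrative sizes with K << d << D
eta = 0.01                 # learning rate (hypothetical value)

# Factored output weights W = V @ U; W is only formed below for checking.
V = rng.normal(scale=0.1, size=(D, d))
U = np.eye(d)
U_inv = np.eye(d)
Q = V.T @ V                # Q = W^T W, valid here because U starts as I

h = rng.normal(size=d)                       # dense last hidden layer
idx = rng.choice(D, size=K, replace=False)   # indices of non-zero targets
val = rng.normal(size=K)                     # values of non-zero targets


def sparse_step(V, U, U_inv, Q, h, idx, val, eta):
    """One squared-error SGD step on W = V @ U with no O(D d) work.

    Mutates V, U, U_inv and Q in place. Returns the loss and dL/dh,
    both evaluated before the update.
    """
    Wty = U.T @ (V[idx].T @ val)     # W^T y in O(K d + d^2), y is K-sparse
    Qh = Q @ h
    loss = h @ Qh - 2.0 * (h @ Wty) + val @ val   # equals ||W h - y||^2
    z = Qh - Wty                     # equals W^T (W h - y)
    grad_h = 2.0 * z

    # Implicit rank-one update W <- W - 2 eta (W h - y) h^T.
    U -= 2.0 * eta * np.outer(U @ h, h)           # U <- U (I - 2 eta h h^T)
    # Sherman-Morrison keeps the inverse current in O(d^2).
    c = 2.0 * eta / (1.0 - 2.0 * eta * (h @ h))
    U_inv += c * np.outer(h, h @ U_inv)
    # Only the K rows of V that match non-zero targets change.
    V[idx] += 2.0 * eta * np.outer(val, h @ U_inv)
    # Keep Q = W^T W current, again in O(d^2).
    Q -= 2.0 * eta * (np.outer(z, h) + np.outer(h, z))
    Q += 4.0 * eta**2 * loss * np.outer(h, h)
    return loss, grad_h


# Naive O(D d) reference, used only to verify the sketch.
W_before = V @ U
y = np.zeros(D)
y[idx] = val
r = W_before @ h - y
loss_naive = r @ r
grad_naive = 2.0 * W_before.T @ r
W_naive = W_before - 2.0 * eta * np.outer(r, h)

loss, grad_h = sparse_step(V, U, U_inv, Q, h, idx, val, eta)
W_after = V @ U
print(np.allclose(W_after, W_naive), np.allclose(loss, loss_naive))
```

Inside `sparse_step` no operation touches more than $K$ rows of a $D$-sized array, so the cost per step is $O(d^2 + Kd)$ regardless of $D$. The Sherman-Morrison step requires $1 - 2\eta\|h\|^2 \neq 0$; a practical implementation would also need to guard against numerical drift in `U_inv` and `Q` over many steps, which this sketch does not address.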