Escaping Saddle Points Faster with Stochastic Momentum

By Jun-Kun Wang, Chi-Heng Lin, and Jacob Abernethy

SGD with stochastic momentum (see Figure 1 below) has been the de facto training algorithm in nonconvex optimization and deep learning. It has been widely adopted for training neural nets in various applications. Modern techniques in computer vision (e.g. [1,2]), speech recognition (e.g. [3]), natural language processing (e.g. [4]), and reinforcement learning (e.g. [5]) use SGD with stochastic momentum to train models. To be precise, during training the algorithm is applied to solve $\min_w f(w) := \mathbb{E}_{\xi \sim \mathcal{D}}[\ell(w; \xi)]$, where the loss function $\ell(w; \xi)$ measures the performance of a neural net with weights $w$ on a mini-batch of samples $\xi$, and $f$ is typically non-convex.

Figure 1. SGD with stochastic momentum. The algorithm maintains a momentum buffer $m_t = \beta m_{t-1} + g_t$, where $g_t$ is a stochastic gradient of the loss at $w_t$, and updates $w_{t+1} = w_t - \eta m_t$.
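To make the update concrete, here is a minimal NumPy sketch of the procedure in Figure 1 (my own illustration; the helper stochastic_gradient and the hyperparameter values are placeholders, not part of the original post):

import numpy as np

def sgd_with_stochastic_momentum(w0, stochastic_gradient, lr=0.01, beta=0.9, num_steps=1000):
    # Runs the update of Figure 1: m_t = beta * m_{t-1} + g_t, w_{t+1} = w_t - lr * m_t.
    w = np.array(w0, dtype=float)
    m = np.zeros_like(w)                # momentum buffer, m_0 = 0
    for _ in range(num_steps):
        g = stochastic_gradient(w)      # g_t: stochastic gradient of the mini-batch loss at w_t
        m = beta * m + g                # accumulate momentum
        w = w - lr * m                  # take the step
    return w

Setting beta = 0 recovers plain SGD; beta = 0.9 is the common default discussed below.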

Despite the wide use of stochastic momentum in practice, a theoretical justification for its superiority has remained elusive, as has any mathematical guideline for actually setting the momentum parameter $\beta$; it has been observed that large values (e.g. $\beta = 0.9$) work well in practice. We note that PyTorch, one of the most popular deep learning toolkits, recommends a high value of the momentum parameter, $\beta = 0.9$, for training neural nets in its tutorial page.
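For reference, a small illustration of how this parameter is set in PyTorch (the model, learning rate, and data below are placeholders of my own choosing):

import torch

model = torch.nn.Linear(10, 1)                                    # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # momentum parameter beta = 0.9

x, y = torch.randn(32, 10), torch.randn(32, 1)                    # placeholder mini-batch
loss = torch.nn.functional.mse_loss(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()                                                        # one SGD-with-momentum update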


At the same time, a widely observed empirical phenomenon is that in training deep networks stochastic momentum appears to significantly improve convergence time, and variants of it have flourished in the development of other popular update methods, e.g. ADAM [6] and AMSGrad [7]. Yet a theoretical justification for the use of stochastic momentum has remained a significant open question.

Our work proposes an answer: stochastic momentum improves deep network training because it modifies SGD to escape saddle points faster and, consequently, leads to faster convergence. For practitioners, our result provides guidance for setting the value of the momentum parameter, which is important for speeding up deep learning. For theorists, our work sheds light on the interaction between optimization and deep learning.

More Details

 

Figure 3. Illustration of a (strict) saddle point region, indicated by the star sign. In this region the gradient is small but there is an escape direction, namely the eigenvector corresponding to the smallest eigenvalue of the Hessian, $\lambda_{\min}(\nabla^2 f(w_t))$. The figure is borrowed from a great blog article by Chi Jin and Michael Jordan.

Let us provide some high-level intuition about the benefit of stochastic momentum with respect to escaping saddle points. In an iterative update scheme, at some time $t$ the iterate $w_t$ can enter a saddle point region (see Figure 3), that is, a place where the Hessian $\nabla^2 f(w_t)$ has a non-trivial negative eigenvalue, say $\lambda_{\min}(\nabla^2 f(w_t)) \le -\epsilon$, and the gradient $\nabla f(w_t)$ is small in norm, say $\|\nabla f(w_t)\| \le \epsilon$. The challenge here is that gradient updates may drift only very slowly away from the saddle point, and may not escape this region. On the other hand, if the iterates were to move in one particular direction, namely along the eigenvector corresponding to the smallest eigenvalue of $\nabla^2 f(w_t)$, then a fast escape is guaranteed under certain constraints on the step size $\eta$. While this negative-curvature eigenvector could be computed directly, such a second-order method is prohibitively expensive for state-of-the-art neural nets with millions of parameters, and hence we typically rely on first-order methods such as SGD with stochastic momentum.
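To make the escape direction concrete, here is a minimal sketch of this (deliberately naive) second-order computation, assuming we could afford to form the full Hessian; the toy matrix below is my own choice and matches the one used in the experiment later:

import numpy as np

H = np.diag([1.0, -0.1])                # toy Hessian with one negative eigenvalue
eigvals, eigvecs = np.linalg.eigh(H)    # eigh returns eigenvalues in ascending order
lambda_min = eigvals[0]                 # smallest eigenvalue: -0.1 < 0, so this is a strict saddle
v = eigvecs[:, 0]                       # corresponding eigenvector: the escape direction [0, 1] (up to sign)
print(lambda_min, v)

For a network with millions of parameters, forming and eigendecomposing the Hessian is far too expensive, which is why the discussion focuses on first-order methods.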

Figure 4. Near a saddle point, the momentum $m_t$ can remain correlated with the escape direction over successive iterations, accelerating the escape.

Now let us conduct a thought experiment. Denote the escape direction (i.e. the negative curvature direction) by $v$. Assume that at some iterate $w_t$ the momentum $m_t$ has a significant correlation with the negative curvature direction $v$; then on successive rounds $m_{t+1}$ is quite close to $m_t$, $m_{t+2}$ is quite close to $m_{t+1}$, and so forth; see Figure 4 for an example. This provides an intuitive perspective on how momentum might help accelerate the escape process. Yet one might ask: does this procedure provably contribute to the escape process? And, if so, what is the aggregate performance improvement of the momentum? We answer the first question in the affirmative, and we answer the second question with the following result.

Main result (informal): the higher the momentum parameter $\beta$, the faster SGD with stochastic momentum escapes saddle points, and consequently the faster it converges to an approximate second-order stationary point.
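To get a feel for the size of the effect, here is a back-of-the-envelope version of the thought experiment above (my own illustration, not the formal argument in the paper): suppose the stochastic gradient keeps a roughly constant component g_v along the escape direction $v$ while the iterate sits near the saddle, and project the update $m_t = \beta m_{t-1} + g_t$ onto $v$:

beta = 0.9
g_v = 0.01                      # assumed (constant) gradient component along the escape direction
m_v = 0.0                       # momentum component along v
for t in range(30):
    m_v = beta * m_v + g_v      # after t steps: m_v = g_v * (1 - beta**t) / (1 - beta)
print(m_v)                      # about 0.096, close to the limit g_v / (1 - beta) = 0.1

In this idealized picture the per-step displacement along the escape direction is amplified by roughly a factor of $1/(1-\beta)$ (10 for $\beta = 0.9$) compared to plain SGD, which is consistent with larger momentum escaping faster.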

Experiment

Let us consider a toy example to demonstrate that SGD with stochastic momentum escapes saddle points faster with higher values of $\beta$. We follow the work of [8], which considers minimizing the quadratic $f(w) = \tfrac{1}{2} w^\top A w$ with an embedded saddle given by the matrix $A := \mathrm{diag}([1, -0.1])$, with stochastic Gaussian perturbations added to the gradient whose variance in the second component is much smaller than in the first. The small variance in the second component yields only a small projection of the gradient onto the escape direction. The experiment sets the initial iterate $w_0$ at the saddle point (i.e. the origin) and applies SGD with momentum to solve the optimization problem for different values of $\beta$. Figure 5 clearly demonstrates that the higher the momentum parameter $\beta$, the faster the escape process.

Figure 5. Escaping the saddle point under different values of the momentum parameter $\beta$: the larger $\beta$ is, the faster the iterates move away from the saddle.
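For completeness, here is a minimal sketch of this toy experiment; the learning rate, number of steps, and exact noise covariance are placeholders of my own choosing (only the qualitative setup, small noise along the escape direction, follows the description above):

import numpy as np

def escape_distance(beta, lr=0.05, steps=300, seed=0):
    rng = np.random.default_rng(seed)
    A = np.diag([1.0, -0.1])                 # saddle of f(w) = 0.5 * w^T A w; escape direction = 2nd coordinate
    noise_std = np.array([0.1, 0.01])        # assumed noise scales; small variance along the escape direction
    w = np.zeros(2)                          # start at the saddle point (the origin)
    m = np.zeros(2)
    for _ in range(steps):
        g = A @ w + noise_std * rng.standard_normal(2)   # stochastic gradient
        m = beta * m + g                                 # momentum update
        w = w - lr * m
    return np.linalg.norm(w)                 # distance from the saddle after `steps` iterations

for beta in [0.0, 0.5, 0.9]:
    print(beta, escape_distance(beta))       # larger beta typically ends up farther from the saddle

Plotting this distance over the iterations for each $\beta$ gives curves qualitatively like Figure 5.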

References

[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. Conference on Computer Vision and Pattern Recognition (CVPR), 2016.

[2] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. ImageNet classification with deep convolutional neural networks. NIPS, 2012.

[3] Geoffrey Hinton, Li Deng, Dong Yu, et al. Deep neural networks for acoustic modeling in speech recognition: the shared views of four research groups. IEEE Signal Processing Magazine, 2012.

[4] Ashish Vaswani, Noam Shazeer, Niki Parmar, et al. Attention is all you need. NIPS, 2017.

[5] David Silver, Julian Schrittwieser, Karen Simonyan, et al. Mastering the game of Go without human knowledge. Nature, 2017.

[6] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. ICLR, 2015.

[7] Sashank J. Reddi, Satyen Kale, and Sanjiv Kumar. On the convergence of Adam and beyond. ICLR, 2018.

[8] Sashank Reddi, Manzil Zaheer, Suvrit Sra, Barnabas Poczos, Francis Bach, Ruslan Salakhutdinov, and Alex Smola. A generic approach for escaping saddle points. AISTATS, 2018.
