NFNet: High-Performance
Large-Scale Image Recognition
Without Normalization
Alexey Zinoviev, JetBrains
Slide 2
Slide 2 text
Bio
● Java & Kotlin developer
● Distributed ML enthusiast
● Apache Ignite PMC
● TensorFlow Contributor
● ML engineer at JetBrains
● Happy father and husband
● https://github.com/zaleslaw
● https://twitter.com/zaleslaw
Slide 3
Slide 3 text
NFNets: top-1 accuracy vs training latency
Slide 4
Slide 4 text
Evolution of Image Recognition models
Slide 5
Slide 5 text
BatchNorm + Skip Connections era
NFNet-F4+
Slide 6
Slide 6 text
Top-1 Accuracy 2021
Slide 7
Slide 7 text
Architecture innovations: 2015-2019
Slide 8
Slide 8 text
Residual Block
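A residual block adds its input back to the output of a small stack of convolutions, so the block only has to learn a correction F(x) on top of the identity: y = x + F(x). A minimal PyTorch sketch (illustrative names; the original ResNet block also contains BatchNorm, which is exactly what NFNets later remove):

import torch.nn as nn

class ResidualBlock(nn.Module):
    # y = x + F(x): the skip connection carries the input around two 3x3 convolutions
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x + self.conv2(self.relu(self.conv1(x))))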
Slide 9
Slide 9 text
Bottleneck Residual Block
Slide 10
Slide 10 text
Bottleneck Residual Block
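The bottleneck variant keeps the same skip connection but squeezes the channels with a 1x1 convolution before the 3x3 convolution and expands them back afterwards, which is much cheaper at large widths. A hedged PyTorch sketch (illustrative names and ratio; normalization omitted):

import torch.nn as nn

class BottleneckBlock(nn.Module):
    # 1x1 reduce -> 3x3 conv in the narrow space -> 1x1 expand, plus the identity skip
    def __init__(self, channels, bottleneck_ratio=4):
        super().__init__()
        mid = channels // bottleneck_ratio
        self.reduce = nn.Conv2d(channels, mid, kernel_size=1)
        self.conv = nn.Conv2d(mid, mid, kernel_size=3, padding=1)
        self.expand = nn.Conv2d(mid, channels, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.reduce(x))
        out = self.relu(self.conv(out))
        return self.relu(x + self.expand(out))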
Slide 11
Slide 11 text
Depthwise Separable Convolution
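A depthwise separable convolution factorizes a standard convolution into a per-channel (depthwise) 3x3 convolution followed by a 1x1 (pointwise) convolution that mixes channels, cutting parameters and FLOPs substantially. A minimal PyTorch sketch (illustrative):

import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # groups=in_channels: each filter sees only its own input channel (depthwise)
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                   padding=1, groups=in_channels)
        # 1x1 convolution mixes information across channels (pointwise)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))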
Slide 12
Slide 12 text
Inception Module
Slide 13
Slide 13 text
Multi-branch residual block: ResNeXt
Slide 14
Slide 14 text
Shared source skip connections
Slide 15
Slide 15 text
Dense blocks in DenseNet (DenseNet != MLP)
Slide 16
Slide 16 text
AutoML Era
Slide 17
Slide 17 text
A new principle of the AutoML era
Give us your computational resource limit, and we will scale our baseline model for you
Slide 18
Slide 18 text
EfficientNet: scale
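As a reminder of EfficientNet's compound scaling rule: one coefficient φ (set by the available compute) scales depth, width, and input resolution together, with the base ratios found by a small grid search on the B0 baseline:

d = \alpha^{\phi}, \qquad w = \beta^{\phi}, \qquad r = \gamma^{\phi},
\qquad \text{subject to } \alpha \cdot \beta^{2} \cdot \gamma^{2} \approx 2, \quad \alpha, \beta, \gamma \ge 1

where d, w, r multiply the baseline depth, width, and resolution respectively.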
Slide 19
Slide 19 text
3D Pareto front: FLOPS, number of parameters, and top-1 accuracy
Slide 20
Slide 20 text
HPO (hyperparameter optimization) task
Slide 21
Slide 21 text
Bilevel optimization problem
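Written out, hyperparameter optimization (and NAS as a special case) is a bilevel problem: the outer level chooses the hyperparameters λ by validation loss, while the inner level trains the weights on the training set for that λ:

\lambda^{*} = \arg\min_{\lambda} \; \mathcal{L}_{\mathrm{val}}\big(w^{*}(\lambda), \lambda\big)
\quad \text{s.t.} \quad
w^{*}(\lambda) = \arg\min_{w} \; \mathcal{L}_{\mathrm{train}}(w, \lambda)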
Slide 22
Slide 22 text
NAS (neural architecture search)
Slide 23
Slide 23 text
Search Space skeleton
Slide 24
Slide 24 text
Controller implementation (RNN, isn’t it?)
Slide 25
Slide 25 text
NASNet cells designed by AutoML algorithm
Slide 26
Slide 26 text
Evolutionary algorithms
Slide 27
Slide 27 text
Evolutionary AutoML
Slide 28
Slide 28 text
Evolution vs RL vs Random Search
Slide 29
Slide 29 text
AmoebaNet: pinnacle of evolution
Slide 30
Slide 30 text
No content
Slide 31
Slide 31 text
Batch Normalization
Slide 32
Slide 32 text
Internal Covariate Shift
Slide 33
Slide 33 text
Batch Norm algorithm
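For reference, the batch-norm transform from Ioffe & Szegedy (2015): normalize each activation by the mini-batch statistics, then let the network undo it with a learnable scale and shift:

\mu_{\mathcal{B}} = \frac{1}{m} \sum_{i=1}^{m} x_i, \qquad
\sigma_{\mathcal{B}}^{2} = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\mathcal{B}})^{2}, \qquad
\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2} + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta

At inference time the mini-batch statistics are replaced by running averages, which is exactly why frameworks need the separate “training” and “inference” modes mentioned on the next slide.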
Slide 34
Slide 34 text
Bad parts
● Batch normalization is expensive
● Batch normalization breaks the assumption of data independence
● Introduces extra hyper-parameters that need further fine-tuning
● Causes many implementation errors in distributed training
● Requires separate “training” and “inference” modes in frameworks
Slide 35
Slide 35 text
The philosophy of this paper
Identify the origin of BatchNorm’s benefits and replicate these benefits in
BatchNorm-free Neural Networks
Slide 36
Slide 36 text
Good parts
● Batch normalization downscales the residual branch
● Batch normalization eliminates mean-shift (in ReLU networks)
● Batch normalization has a regularizing effect
● Batch normalization allows efficient large-batch training
Slide 37
Slide 37 text
Early batch-norm-free architectures
Slide 38
Slide 38 text
Residual Branch Downscaling effect: SkipInit
Batch Normalization Biases Deep Residual Networks Towards Shallow Paths
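SkipInit reproduces BatchNorm's downscaling of the residual branch with a single learnable scalar that multiplies the branch output and is initialized to zero, so at initialization every block is the identity and the deep network behaves like a shallow one. A hedged PyTorch sketch (residual_branch stands for any stack of convolutions):

import torch
import torch.nn as nn

class SkipInitBlock(nn.Module):
    def __init__(self, residual_branch):
        super().__init__()
        self.residual_branch = residual_branch
        # Learnable scalar initialized to zero: at init the block is the pure identity
        self.skip_gain = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        return x + self.skip_gain * self.residual_branch(x)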
Slide 39
Slide 39 text
Removing Mean Shift: Scaled Weight Standardization
Characterizing signal propagation to close the performance gap in
unnormalized ResNets
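Scaled Weight Standardization attacks the mean shift at its source: the weights of every output unit are standardized to zero mean and unit variance, scaled by 1/sqrt(fan-in), and multiplied by a fixed nonlinearity-dependent gain before the convolution is applied. A hedged PyTorch sketch (simplified; the full NF-ResNet version also has a learnable per-channel affine gain):

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledWSConv2d(nn.Conv2d):
    # Standardize each output filter's weights, then scale by gamma / sqrt(fan-in)
    def __init__(self, *args, gamma=math.sqrt(2 / (1 - 1 / math.pi)), **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma  # ReLU-specific gain from the paper above

    def forward(self, x):
        w = self.weight
        fan_in = w[0].numel()
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = w.var(dim=(1, 2, 3), keepdim=True)
        w = self.gamma * (w - mean) / torch.sqrt(var * fan_in + 1e-4)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)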
Slide 40
Slide 40 text
NFNet improvements
Slide 41
Slide 41 text
Improved SkipInit
Slide 42
Slide 42 text
Modified Scaled Weight Standardization
Slide 43
Slide 43 text
Regularization: Dropout
(meme image; labels: Dropout, CNNs, BatchNorm, me)
Slide 44
Slide 44 text
Regularization: Stochastic Depth
Slide 45
Slide 45 text
Regularization: Stochastic Depth
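Stochastic depth regularizes very deep residual networks by randomly dropping whole residual branches during training (only the skip connection survives), so each mini-batch effectively trains a shallower sub-network. A hedged PyTorch sketch using the “inverted” convention, where the branch is rescaled at training time so inference needs no change:

import torch
import torch.nn as nn

class StochasticDepthBlock(nn.Module):
    def __init__(self, residual_branch, survival_prob=0.8):
        super().__init__()
        self.residual_branch = residual_branch
        self.survival_prob = survival_prob

    def forward(self, x):
        if self.training:
            # Drop the whole residual branch with probability 1 - survival_prob
            if torch.rand(()) > self.survival_prob:
                return x
            return x + self.residual_branch(x) / self.survival_prob
        return x + self.residual_branch(x)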
Slide 46
Slide 46 text
How to train on larger batches?
Slide 47
Slide 47 text
Gradient Clipping
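Standard gradient clipping rescales the gradient whenever its norm exceeds a fixed threshold λ, which stabilizes training with large batches and high learning rates:

G \leftarrow
\begin{cases}
\lambda \, \dfrac{G}{\lVert G \rVert} & \text{if } \lVert G \rVert > \lambda, \\
G & \text{otherwise.}
\end{cases}

In PyTorch this is what torch.nn.utils.clip_grad_norm_ does; the catch is that the threshold λ becomes yet another sensitive hyper-parameter.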
Slide 48
Slide 48 text
Intuition about Gradient Clipping
Parameter updates should be small relative to the magnitude of the weight
Slide 49
Slide 49 text
Adaptive Gradient Clipping
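Adaptive Gradient Clipping (AGC) makes the threshold relative: a gradient is clipped only when its norm is large compared to the norm of the corresponding weight, so the constraint is on the ratio of update size to weight size rather than on an absolute number. A hedged, simplified per-tensor sketch in PyTorch (the paper actually clips unit-wise, i.e. per output row of each weight matrix; clip_factor and eps values are illustrative):

import torch

def adaptive_gradient_clip_(parameters, clip_factor=0.01, eps=1e-3):
    # Clip each gradient so that ||G|| / max(||W||, eps) <= clip_factor
    for p in parameters:
        if p.grad is None:
            continue
        w_norm = torch.clamp(p.detach().norm(), min=eps)
        g_norm = p.grad.detach().norm()
        max_norm = clip_factor * w_norm
        if g_norm > max_norm:
            p.grad.mul_(max_norm / (g_norm + 1e-6))

Call it between loss.backward() and optimizer.step().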
Slide 50
Slide 50 text
Not enough Top-1 Accuracy!!!
Even with Adaptive Gradient Clipping and the modified residual branches and convolutions, normalizer-free networks still could not surpass the accuracy of EfficientNets.
Slide 51
Slide 51 text
New SOTA model family
Slide 52
Slide 52 text
Architecture optimization for improved accuracy and
training speed
● SE-ResNeXt-D as a baseline
● Fixed group width (specific to the ResNeXt architecture)
● The depth-scaling pattern was changed (from very specific to very simple)
● The width pattern was changed as well
Slide 53
Slide 53 text
Bottleneck block design
Slide 54
Slide 54 text
No content
Slide 55
Slide 55 text
No content
Slide 56
Slide 56 text
The whole NFNet family
Slide 57
Slide 57 text
Implementation
● Official JAX implementation from the authors
● Colab notebook to play with
● PyTorch with weights
● Yet another PyTorch
● Very good PyTorch (clear code)
● Adaptive Gradient Clipping example
● Broken Keras example (could be a good entry point)
● Raw TF implementation
Slide 58
Slide 58 text
SAM: Sharpness Aware Minimization
Slide 59
Slide 59 text
SAM: Sharpness Aware Minimization
Slide 60
Slide 60 text
SAM: Two backprops and approximate grads
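SAM needs two backward passes per update: the first gradient is used to climb to the (approximate) worst-case point inside a ρ-ball around the current weights, and the second gradient, computed at that perturbed point, is what the optimizer actually applies. A hedged PyTorch-style sketch of one SAM step (function and argument names are illustrative):

import torch

def sam_step(model, loss_fn, inputs, targets, optimizer, rho=0.05):
    params = [p for p in model.parameters() if p.requires_grad]

    # 1st forward/backward: gradient at the current weights w
    optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()

    # Ascend to the approximate worst-case point: w + rho * g / ||g||
    with torch.no_grad():
        grad_norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params if p.grad is not None))
        eps = []
        for p in params:
            e = rho * p.grad / (grad_norm + 1e-12) if p.grad is not None else torch.zeros_like(p)
            p.add_(e)
            eps.append(e)

    # 2nd forward/backward: the sharpness-aware gradient at the perturbed weights
    optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()

    # Restore the original weights and apply the update using the SAM gradient
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)
    optimizer.step()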
Slide 61
Slide 61 text
Accelerating Sharpness-Aware Minimization
The idea is to run the SAM step on only 20% of each batch.
Summary
How NFNets replace each BatchNorm benefit:
1. Downscales the residual branch → NF strategy
2. Enables large-batch training → Adaptive Gradient Clipping
3. Implicit regularization → Explicit regularization
4. Prevents mean-shift → Scaled Weight Standardization
● NFNets were the new SOTA for a few months
● NFResNet >= BNResNet (normalizer-free ResNets match or beat their batch-normalized counterparts)
● NFNet training >>> EfficientNet training (much faster to train at comparable accuracy)