Advanced Efficient Techniques
=============================
In addition to standard model compression methods, several advanced
approaches have been developed to accelerate decoding in large models.
Some of these methods use a smaller model to propose tokens and let the
large model verify multiple tokens in a single step, thereby speeding up
decoding. Others exploit the memory hierarchy for high-throughput
computation, aiming to decrease memory I/O and, as a result, run more
efficiently.
Speculative Decoding
--------------------
Speculative decoding is a strategy to speed up the decoding process,
based on two insights provided by Leviathan et al. [@leviathan2023fast]:

1. Complex modeling tasks frequently encompass simpler subtasks that can
   be effectively approximated by more efficient models.
2. By combining speculative execution with a novel sampling scheme,
   exact decoding from the larger model can be accelerated: the larger
   model processes the outputs of the approximation model in parallel.
Figure :numref:`ch-deploy/sd` gives a brief overview of speculative
decoding. A series of tokens is first generated by a draft model, a
smaller and less complex model. These draft tokens are then verified in
parallel by the target model, the larger model. The tokens that end up
in the output are the draft tokens accepted by the target model. If a
rejection occurs, one additional token is resampled from an adjusted
distribution; if there is no rejection, an extra token is generated by
the target model using the draft tokens as context.
.. container:: center

   Speculative Decoding Overview
To elaborate, the process begins with the draft model generating a
series of :math:`\gamma` tokens, denoted :math:`x_1, x_2, ..., x_{\gamma}`,
while preserving their distributions
:math:`q_{1}(x), q_{2}(x), ..., q_{\gamma}(x)` for later verification by
the target model. These :math:`\gamma` tokens are then fed into the
target model in a single parallel pass,
:math:`M_{\text{target}}(\text{prefix}, x_1, ..., x_{\gamma})`, which
yields the distributions :math:`p_{1}(x), p_{2}(x), ..., p_{\gamma+1}(x)`.
Each draft token :math:`x_i` is then checked in order: if
:math:`q_i(x_i) \leq p_i(x_i)`, the token is accepted; otherwise it is
rejected with probability :math:`1 - \frac{p_i(x_i)}{q_i(x_i)}`, in
which case a replacement token is sampled from the adjusted
distribution:
.. math:: p'(x) = \mathrm{norm}(\max(0, p(x) - q(x)))
   :eqlabel:`equ:sd_adjusted`
In [@leviathan2023fast], Leviathan et al. prove that accepting and
resampling in this way produces tokens distributed exactly as if they
had been sampled from the target model alone.
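As a minimal sketch of this acceptance rule, the snippet below uses
plain NumPy with assumed inputs: ``draft_tokens`` holds the
:math:`\gamma` draft tokens, while ``q`` and ``p`` hold the per-position
draft and target distributions.

.. code-block:: python

   import numpy as np

   def speculative_accept(draft_tokens, q, p, rng=np.random.default_rng()):
       """Accept or reject draft tokens; q[i] and p[i] are vocab-sized
       probability arrays for position i (p has one extra entry)."""
       out = []
       for i, x in enumerate(draft_tokens):
           # Accept x with probability min(1, p_i(x) / q_i(x)).
           if rng.random() < min(1.0, p[i][x] / q[i][x]):
               out.append(x)
           else:
               # On rejection, resample from p'(x) = norm(max(0, p(x) - q(x))).
               adjusted = np.maximum(0.0, p[i] - q[i])
               adjusted /= adjusted.sum()
               out.append(rng.choice(len(adjusted), p=adjusted))
               return out  # stop at the first rejection
       # All gamma tokens accepted: sample a bonus token from p_{gamma+1}.
       out.append(rng.choice(len(p[-1]), p=p[-1]))
       return out

At most one token per step is drawn from the adjusted distribution,
matching the procedure described above.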
Under the assumption that the execution time for a single step of the
target model is :math:`T` and that of the draft model is :math:`cT`,
where :math:`0 < c < 1`, Leviathan et al. show that the expected
improvement factor in total wall-time is
:math:`\frac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c+1)}`, where
:math:`\alpha` denotes the expected probability that a draft token is
accepted [@leviathan2023fast].
FlashAttention
--------------

.. container:: center

   Memory Hierarchy Overview

FlashAttention [@dao2022flashattention] is an IO-aware exact attention
algorithm that exploits the GPU memory hierarchy: a small amount of fast
on-chip SRAM and a much larger but slower high-bandwidth memory (HBM).
Instead of materializing the full :math:`N \times N` score matrix **S**
and probability matrix **P** in HBM, it splits **Q**, **K** and **V**
into blocks, loads the blocks into SRAM, and computes attention block by
block (**Tiling**). The row-wise running maximum :math:`m(x)` and
running sum of exponentials :math:`s(x)` of the softmax are maintained
so that the partial outputs **O** can be rescaled incrementally as new
blocks are processed.
.. container:: center

   FlashAttention Overview with Two Blocks
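To make the tiling idea concrete, the following simplified single-head
sketch in NumPy (an illustrative analogue, not the actual CUDA kernel)
processes **K** and **V** in blocks while maintaining the running
row-wise max :math:`m(x)` and sum of exponentials :math:`s(x)`, so the
full :math:`N \times N` score matrix is never materialized at once.

.. code-block:: python

   import numpy as np

   def flash_attention_forward(Q, K, V, block_size=128):
       """Simplified single-head tiled attention with an online softmax."""
       N, d = Q.shape
       O = np.zeros((N, d))       # unnormalized output accumulator
       m = np.full(N, -np.inf)    # running row-wise max of the scores
       s = np.zeros(N)            # running row-wise sum of exponentials
       scale = 1.0 / np.sqrt(d)
       for j in range(0, N, block_size):
           Kj, Vj = K[j:j + block_size], V[j:j + block_size]
           S = (Q @ Kj.T) * scale              # scores for this block only
           m_new = np.maximum(m, S.max(axis=1))
           correction = np.exp(m - m_new)      # rescale previously seen blocks
           P = np.exp(S - m_new[:, None])
           O = O * correction[:, None] + P @ Vj
           s = s * correction + P.sum(axis=1)
           m = m_new
       return O / s[:, None], m, s             # normalize once at the end

The returned :math:`m` and :math:`s` play the role of the softmax
statistics that FlashAttention stores in HBM for the backward pass.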
**Recomputation**:

Standard attention requires :math:`O(N^2)` memory to store the
intermediate matrices **S** and **P** for computing the gradients
w.r.t. **Q**, **K**, **V** in the backward pass. In FlashAttention,
**S** and **P** are instead recomputed in SRAM from the blocks of **Q**
and **K** together with the HBM-stored statistics :math:`s(x)`,
:math:`m(x)` and output **O**, so only :math:`O(N)` extra memory is
required. Furthermore, FlashAttention performs fewer HBM accesses than
standard attention, which results in a faster runtime
[@dao2022flashattention].
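For instance, a probability block can be rebuilt on the fly from a block
of scores and the saved statistics; a small sketch, reusing the
:math:`m` and :math:`s` from the forward sketch above:

.. code-block:: python

   import numpy as np

   def recompute_block(Q, Kj, m, s):
       """Rebuild the score block S and probability block P needed in the
       backward pass from Q, one K block, and the saved statistics m, s."""
       S = (Q @ Kj.T) / np.sqrt(Q.shape[1])
       P = np.exp(S - m[:, None]) / s[:, None]   # softmax probabilities
       return S, P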
The standard FlashAttention implementation does not eliminate the
redundant computation of zero elements within the attention mechanism.
To address this, a block-level mask is incorporated so that computation
is performed only on non-zero blocks. Termed Block-Sparse
FlashAttention, this approach is also discussed in
[@dao2022flashattention]. By exploiting sparsity, Block-Sparse
FlashAttention reduces the dominant term of the I/O complexity in
proportion to the sparsity, leading to a direct improvement in
performance.
However, FlashAttention has not been fully optimized. Dao noted that its
inefficiency stems from suboptimal work distribution among various
thread blocks and warps on the GPU. This leads to either low occupancy
or unnecessary shared memory reads and writes. To address this, Dao
proposed **FlashAttention-2** [@dao2023flashattention2], which improves
parallelism and work partitioning.
FlashAttention-2 includes several tweaks to reduce the number of
non-matmul operations:

1. Keep the output blocks of **O** unscaled until the very end of the
   loop.
2. Instead of saving both :math:`s(x)` and :math:`m(x)` in HBM, save
   only :math:`\mathrm{logsumexp}_{i} = m_{i} + \log(s_{i})`, which is
   sufficient for the backward pass (see the sketch after this list).
3. For causal masking, blocks whose column indices are all greater than
   their row indices, which make up about half of the blocks for long
   sequences, are skipped entirely. This yields a 1.7-1.8X speedup
   compared to not skipping them.
4. Use only the row-wise :math:`\mathrm{logsumexp}` in the softmax
   instead of both the row-wise max :math:`m(x)` and the row-wise sum of
   exponentials :math:`s(x)`.
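As a small sketch of the second and fourth tweaks (function names are
illustrative), storing only the row-wise :math:`\mathrm{logsumexp}` is
enough because the probabilities can be rebuilt as
:math:`P = \exp(S - L)`:

.. code-block:: python

   import numpy as np

   # Forward pass: store a single vector L = m + log(s) per row.
   def save_statistics(m, s):
       return m + np.log(s)             # row-wise logsumexp

   # Backward pass: rebuild the probability block from the scores and L,
   # since exp(S - L) = exp(S - m) / s.
   def recompute_probs(S, L):
       return np.exp(S - L[:, None])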
Regarding parallelism, the original version of FlashAttention
parallelizes over the batch size and the number of heads, with one
thread block processing one attention head, so there are as many thread
blocks as the product of the batch size and the number of heads. This
works well on an A100 GPU, which has 108 Streaming Multiprocessors
(SMs), as long as the number of thread blocks is large enough to occupy
most of the SMs, e.g. 80 or more.
However, long sequences typically come with small batch sizes or few
heads, so the number of thread blocks shrinks and this scheme becomes
inefficient. FlashAttention-2 therefore adds parallelization over the
sequence-length dimension, which significantly speeds up these cases by
improving GPU occupancy, i.e. the fraction of GPU resources being used.
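To see why this matters, the arithmetic below compares the number of
thread blocks under the two schedules for an assumed long-sequence,
small-batch workload on an A100 with 108 SMs.

.. code-block:: python

   batch, heads, seqlen, block_m = 2, 16, 8192, 128
   sms = 108                                        # SMs on an A100

   blocks_v1 = batch * heads                        # FlashAttention: 32 thread blocks
   blocks_v2 = batch * heads * (seqlen // block_m)  # FlashAttention-2: 2048 thread blocks

   print(blocks_v1 < sms)    # True: most SMs sit idle
   print(blocks_v2 >= sms)   # True: enough blocks to fill the GPU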
In the forward pass, the method schedules different parts of the
sequence length on different thread blocks that operate independently.
The backward pass also incorporates parallelization over the sequence
length. To update the gradients of the query matrix **dQ**, it uses
atomic additions to synchronize updates between different thread blocks.
Within each thread block, the work partitioning among warps is also
important; usually 4 to 8 warps are allocated to each thread block.
FlashAttention-2 introduces significant improvements to this
partitioning in both the forward and backward passes of the algorithm.
In the forward pass, FlashAttention splits **K** and **V** across the 4
warps (the "split-K" scheme), which leads to inefficient shared memory
operations. FlashAttention-2 instead splits **Q** across the warps while
keeping **K** and **V** accessible to all of them. Each warp directly
multiplies its slice of **Q** with **K** and then with **V** to obtain
its slice of the output, which eliminates the need for inter-warp
communication and reduces shared memory reads/writes, resulting in a
faster runtime. In the backward pass, FlashAttention-2 likewise avoids
the "split-K" scheme, arranging the warps so that shared memory
operations are minimized. Although some synchronization is still
required because of the more complex dependencies among inputs and
gradients, this approach still yields a speedup by reducing shared
memory reads/writes.
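The effect of the split-Q partitioning can be sketched with warps
modeled as independent slices (an illustrative NumPy analogue, not the
real kernel): each warp produces a disjoint slice of the output with no
cross-warp reduction, whereas split-K would require combining partial
results across warps.

.. code-block:: python

   import numpy as np

   def split_q_attention(Q, K, V, num_warps=4):
       """Each 'warp' owns a slice of Q and writes its output slice
       independently; K and V are visible to all warps."""
       scale = 1.0 / np.sqrt(Q.shape[1])
       outputs = []
       for Qw in np.array_split(Q, num_warps):   # split Q across warps
           S = (Qw @ K.T) * scale
           P = np.exp(S - S.max(axis=1, keepdims=True))
           P /= P.sum(axis=1, keepdims=True)
           outputs.append(P @ V)                 # no inter-warp reduction needed
       return np.vstack(outputs)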
FlashAttention has attracted significant attention in industry for its
performance, accelerating attention computation in both the forward and
backward passes while also reducing memory I/O. The enhanced version,
FlashAttention-2, achieves roughly a 2X speedup over the original
FlashAttention [@dao2023flashattention2]. Optimization efforts are
ongoing, promising even faster versions of FlashAttention in the future.