Skip to content

[223] Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention #246

@lianam-tl

Description

@lianam-tl
Image

paper, blog, code

TL;DR

  • I read this because.. : 4월은 linear transformer -- 첫번째 논문
  • task : autoregressive sequence modeling, language modeling, machine translation
  • problem : self-attention이 모든 토큰 쌍을 비교해서 시간/메모리 O(N²), 긴 시퀀스에서 비효율적
  • idea : softmax attention을 kernel 형태 $\phi(Q)\phi(K)^T$로 바꿔서 결합법칙으로 재배열, 이를 통해 cumulative sum을 함
  • input/output : token -> token
  • architecture : softmax 대신 kernel with elu function으로 바꿈. 그 외 아키텍쳐 상 변경점은 없음
  • objective : CE loss
  • baseline : Transformer, RoFormer
  • data : WMT, language modeling benchmark
  • evaluation : BLEU (MT), perplexity (LM)
  • result : 긴 시퀀스에서 큰 속도/메모리 개선, 성능은 약간 감소하거나 유사한 수준
  • contribution : attention을 kernel로 재해석, O(N) linear attention 제안, transformer가 RNN처럼 동작함을 보임
  • etc. : causal masking이 prefix sum 구조에 자연스럽게 포함됨, layerwise parallelism 가능

Details

  • conversation with chatGPT: link
  • X$

3.1 Transformer

Image
  • $f_l(.)$은 그냥 FFN
  • $A_l(.)$ self-attnetion
Image

저기서 softmax term을 그냥 유사도 함수 $sim(\cdot)$로 표현할 수 있음

Image

3.2. Linearized attention

여기가 갑자기 헷갈리는데, Kernel Trick이란 걸 쓸거임.
attention에서 $sim(\cdot)$은 "non-negative"여야 한다는 제약 밖에 없음
그렇다면 모든 kernel 중에 , $k(x,y) : \mathbb{R}^{2 \times F} -> \mathbb{R}_{+}$ 를 포함할 수 있게 됨

그런 "imaginery kernel"($k$)이 있다 치고, feature 표현 $\phi(x)$에 대해 eq (2)를 다시 쓰면

Image

위에서 $\sum _j$는 j에 대한 값이기 때문에 $\phi(Q_i)^T$를 넘길 수 있고, 그러면 아래와 같이 식이 됨
Image

이 때 feature map $\phi(\cdot)$$Q$, $K$ 행렬에 row-wise로 연산됨
eq (6)의 괄호 안은 $\phi(X)^T\in \mathbb{R}^{D\times N}$, $\phi(X)^T\in \mathbb{R}^{N\times D}$ 이어서 $O(N)$의 시간, 공간 복잡도를 가지게 됨. (공간 복잡도와 시간 복잡도가 헷갈리네..) --
그 이유는 우리가 KV, K를 한번 저장하고 재사용할 것이기 때문.
Image

Feature maps and computational cost

Kernel을 어떤 것을 사용하냐에 따라 computational cost가 달라지기 때문에 elu 함수를 선택
Image

relu over elu를 사용한 것은 0 이하일 때도 gradient가 흘렀으면 좋겠어서

3.3 Causal Masking

Transformer의 Causal masking을 여기선 어떻게 구할 수 있냐
이것은 summation을 모든 j에 대해 하는게 아니라 $i$까지 하도록 바꾸면 됨

(이전의 식)

Image

(w/ causal masking)
Image

우리는 $S_{i-1}$로 부터 $S_{i}$를 계산할 수 있음. 왜냐하면 누적합이기 때문에.
여기서 처음 읽을 때 헷갈렸는데, Inference 시에 누적합으로 한다는 것이고 실제 학습 때는 원래의 transformer처럼 causal mask를 적용.

3.3.1 Gradient Computation

gradient를 나이브하게 구하면 또 $O(N^2)$ 복잡도가 되지만 잘 구해서 얘도 Linear하게 함

Image

3.3.2 Training and Inference

Transformer 대비 좋은 점은 Inference 시에 QK를 안가지고 있어도 되어서 메모리가 seq len에 비례하여 늘어나지 않음. 즉 train, inference의 좋은 점을 다 가져옴
Image

3.4. Transformers are RNNs

Image

Experiment

스킵 ㅎ

Metadata

Metadata

Assignees

No one assigned

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions