Previous 연구 정리
unipelt : a unified framework for parameter-efficient language model tuning 소개영상에서 가져옴. (UNIPELT 는 일단정리안함)
출처 : https://www.youtube.com/watch?v=Cmtvh_2MtGg&t=1612s
pelt : adapter, LoRA, prefix learning = language modeling 이 어떻게 효율적으로 fine-tuning 하는가에 대해 해결하기 위해 trainable parameter 수를 (많이) 줄이는 방법론.
PELT 는 위 3가지 방법 (Adapter, LoRA, Prefix Learning) 을 모두 사용하는 방법이다.
이 방법들은 항상 pre-trained weight 는 freeze 시킨다.
이는 학습 parameter 수를 매우 줄이는 방법이다
이 방법들에 대해 간단히 설명 후 LoRA 를 이야기 해보자.
1. Adapter
기존 layer 위에 Adapter layer 를 추가하고 adapter layer 만 학습한다.
이는 직렬 구조로, 단점이 될 수 있다 왜냐하면 multi head attention 결과값을 받아야 어댑터에서 연산을 할 수 있기 때문이다. 또한 inference latency 가 추가적으로 발생할 수 있다.
위의 왼쪽 첫번째 그림은 transformer 에서 adapter layer 가 어디에 위치하는 지에 대한 그림이다.
두번째 그림은 adapter layer 를 좀 더 자세히 표현한 그림이다.
feedforward down-projection layer 는 인풋 d 차원 을 m 차원(bottleneck representation) 으로 줄여주는 역할을 한다.
이후 feedforward up projection layer 를 통해 다시 output 을 인풋과 같은 차원으로 늘려준다.
학습 파라미터 수는 full-finetuning 의 경우 d^2 의 파라미터수 → 2md (m<<d) 만큼 줄어든다.
Final form : $$h \leftarrow h+ f(h W_{down}) W_{up}$$ (f=activation)
** LoRA 에 따르면 Inference latency 를 줄일 수 없는 문제가 있다.
2. Prefix learning
prompting (프롬프트 튜닝, p-tuning 등)가 motivation 이 되었다. ⇒ 즉, 앞에 붙여주는 prefix (접두사, prompt) 만으로도 학습을 할 수 있다는 개념.
prefix (input vector, attention key, query) 를 삽입하고 pre-fix 만 학습을 진행하는 방식. 이때 prefix 는 virtual token 으로 지정한다.
prefix 튜닝은 모델마다 prepend(:앞에 붙이다 라는 뜻) 하는 과정이 다를 수 있다.
GPT 와 같은 decoder model 은 prefix 를 하나만 붙여주지만, BART 와 같은 encoder-decoder 모델은 각 encoder 단에 prefix 를 하나 붙이고 decoder 단에 하나 붙여준다.
P_theta 가 결과적으로 학습해야하는 파라미터이다.
Final form : $$h \leftarrow (1-\lambda) h +\lambda f(xW_{1})W_{2}$$
얘도 parallel 한 병렬적 구조를 가진다.
** LoRA 에 따르면 prompt 를 최적화하는 것이 어렵고, prefix tuning 을 위해 (특히 모델이 가용할 수 있는 512, 1024 등의 sequence length 가 제한되어있는 모델들에게) 시퀀스 길이 일부를 홀드해야하기 때문에 max seq length 길이가 감소하는 단점이 있다.
LoRA : Low-Rank Adaptation of Large Language Models
마지막으로, 원래 소개하려했던 방법론이다.
Abstract
자연어처리의 중요한 패러다임은 general 도메인에 대한 큰 데이터셋 pre-train 과 특정 도메인에 대한 적응(fine-tuning) 으로 이루어져있다.
큰 모델(LLM)을 pre-train 하면서 모든 파라미터를 재교육하는 full fine-tuning 은 엄청난 비용이 든다.
이 논문에서는 Low-Rank Adaptation (LoRA) 를 이용하여 pre-trained 모델의 웨이트를 동결하고 학습이 가능한 rank decomposition 행렬을 트랜스포머의 각 layer 에 추가하여 downstream task 에 대해 학습 파라미터의 개수를 매우 줄이는 방법 을 제안한다.
Adam 으로 파인튜닝된 175B GPT-3 와 비교했을 때 학습가능한 파라미터 개수를 10,000배 줄일 수 있고 GPU 메모리 사용량 또한 3배 줄어든다.
또한 적은 파라미터 개수, 더 높은 학습 throughput (처리량), 추가적인 inference lateny 가 필요없음 에도 불구하고 RoBERTa, DeBERTa, GPT-2, and GPT-3 보다 좋거나 같은 성능을 얻을 수 있다.
1. Introduction
이 연구는 “학습된 over-parametrized (과하게 파라미터화된, 파라미터수가 커짐에 비해 효율적이지않은,,) 모델이 실제로는 low intrinsic dimension 이 존재한다는 것”을 보여주는 Li et al. (2018a); Aghajanyan et al. (2020) 연구에 영향을 받았다.
연구 내용은 다음과 같다.
- low intrinsic dimension : 보통 over-parameterized 모델에 많이 나타나는 특징.
- PLM (pre-trained LM) 은 low intrinsic(본질적인) dimension 을 가짐.
- 더 작은 subspace 에 대한 random projection 에도 불구하고 여전히 효율적으로 학습이 가능함을 보임 : 특히 RoBERTa 의 경우 200개 trainable 파라미터에 대해 90% 의 퍼포먼스를 달성했다고 함.
여기에 영감을 받아 이 논문에서는 adaptation (fine-tuninng 이라고 생각하면 편하다) 중의 weight 에 대한 업데이트가 낮은 ‘intrinsic rank ’ 를 가진다고 가정한다.
LoRA 는 adaptation 동안 pre-train 한 모델의 weight 는 고정시키면서 dense layer 변화의 rank decomposition 행렬을 대신 최적화함으로서 nn 의 몇몇 dense layer 를 간접적으로 학습시킨다. (그림1)
LoRA 의 이점
- pre-train 모델을 공유할 수 있고 다양한 task 에 대한 작은 LoRA 모듈을 만들 수 있다.
그림 1: 공유 모델을 고정시키고 행렬 A 와 B 를 대체하여 효과적으로 task 를 바꿀 수 있다. 그래서 저장 용량과 task 를 바꾸는 overhead 를 줄일 수 있다. - LoRA 는 adaptive optimizer 를 사용할 때 더 효율적으로 학습할 수 있고 하드웨어 조건을 3배까지 낮출 수 있다. 왜냐하면 모든 파라미터에 대해 gradient 를 계산할 필요가 없고 opitmizer 의 state 를 유지할 필요가 없기 때문이다.
- 그 대신 inject 된 더 작은 lower-rank 행렬을 최적화한다.
- LoRA 는 latency 를 줄일 수 있다.
- LoRA 는 다른 이전 방법(예를 들어, prefix-tuning 등)과 합쳐서 사용될 수 있다. Appendix E.
Terminologies and Conventions (용어와 규약)
$d_{model}$ : 트랜스포머의 인풋(768) 아웃풋(1024) dimension
$W_{q}, W_{k}, W_{v}, W_{o}$ : self attention 모듈의 query, key, value, output projection 행렬
$W \ or \ W_{0}$ : pre-trained weight 행렬 (175B 전체모델 파라미터)
$\Delta W$ : adaptation 동안의 축적된 gradient
$r$ : LoRA 모듈의 rank
Adam optimizer 를 사용
트랜스포머 feed forward dimension 으로 $d_{ffn} = 4 \times d_{model}$ 사용
4. Our Method
LoRA 의 principle 은 딥러닝 모델의 어느 dense layer 에도 적용가능하지만, 여기서는 Transformer language models 의 특정 weight 에만 실험한다.
4.1. Low-Rank-Parameterized Update Matrices
nn 모델은 행렬곱셈을 하기 때문에 많은 dense layer 를 가진다.
보통 이 weight 행렬은 full-rank 를 가진다.
??왜? 어떻게 그렇지?
pre-trained weight $W_{0} \in \mathbb{R}^{d \times k}$ 일때,
$W_{0} + \Delta W = W_{0} + BA$ 식의 low rank decomposition 으로 표현하여 업데이트를 제한한다.
이 때, $B \in \mathbb{R}^{d \times r}$ 와 $A \in \mathbb{R}^{r \times k}$, 이때 rank 는 $r << min(d, k)$ 조건을 가진다.
학습동안 $W_{0}$ 은 동결(고정) 되고 gradient update 를 받지 않는다. 그 동안 A 와 B 는 학습 가능한 파라미터를 가진다.
$W_{0}$ 와 $\Delta W = BA$ 는 둘 다 같은 인풋과 곱해지고 그 각각의 output 벡터는 좌표별로 합산된다.
$h = W_{0} x$ 라는 식의 수정된 forward pass 식은 밑과 같다.
그림 1 은 이 논문의 reparameterization 의 식이다.
A 는 random 가우시안 initialization 초기화를 사용하고 B 는zero 초기화를 사용한다.
그래서 학습 초반에는 $\Delta W = BA$ 는 0 이다.
또한 $\Delta W x$ 를 $\frac{\alpha}{r}$ 로 scaling 하는데 여기서 $\alpha$ 는 r 의 상수이다.
Adam optimizer 를 사용할 때 초기화를 적절히 scaling 할 경우, $\alpha$ 는 lr 튜닝과 거의 비슷하게 튜닝한다.
결과적으로 우리는 단순히 $\alpha$ 를 첫번째 r 에 설정하고 튜닝하지 않는다.(?)
이러한 scaling 을 통해 r 을 변화시킬 때마다 hyperparameter 를 다시 튜닝할 필요가 줄어든다.
위 부분은 이해가 안되서 그냥 번역을 적어놓았다..
A Generalization of Full Fine-tuning
일반적으로 fine-tuning 을 통해 pre-trained 된 파라미터의 subset 을 training 한다.
LoRA 는 위처럼 weight 행렬에 대한 누적 gradient 업데이트를 하지않고 pre-train 된 weight 행렬의 LoRA rank r 을 설정하여 full-fine tuning 의 표현력을 어느 정도 복구할 수 있다.
즉, 우리는 학습 가능한 파라미터의 개수를 늘리면 LoRA 의 훈련은 원래 모델 훈련으로 수렴되지만, adapter 기반의 학습방법은 MLP 나 prefix based 방법으로 수렴하여 긴 입력 시퀀스를 처리할 수 없는 모델로 수렴된다. (???)
No Additional Inference Latency.
inference 생성을 시작하면 $W = W_0 + BA$ 를 명시적으로 계산 및 저장하여 평소처럼 추론을 수행할 수 있다.
$W_0$ 및 $BA$ 는 모두 $\mathbb{R}^{d \times k}$ 단위이다.
다른 다운스트림 작업으로 전환해야 할 경우, $BA$ 를 빼고 다른 $B'A'$ 를 추가하여 $W_0$ 을 복구할 수 있다.
그래서 fine-tuning 모델과 비교하여 inference latency 가 발생하지 않는다.
=> eval 을 할때 LoRA 에서는 PLM 은 그대로 두고, LoRA A 와 B 만 가져와서 weight 에 AB 계산값을 더해주면 eval 이 가능하다. (= fine tune 된 값이 적용됨)
↔ 그와 반대로 adapter layer 는 트랜스포머 레이어 중간중간에 넣어줫기 때문에 시퀀셜하게 계산이 증가한다.
4.2. Applying LoRA to Transformers
원칙적으로는 trainable 파라미터의 개수를 줄이기 위해 신경망의 모든 weight 행렬의 subset에 LoRA를 적용할 수 있다.
Transformer 아키텍처에서 self-attention 모듈에는 4개의 가중치 행렬$(W_{q}, W_{k}, W_{v}, W_{o})$ 이 있고 MLP 모듈에는 2개의 가중치 행렬이 있다.
multi-head 때문에 attn weighr 의 output dim이 일반적으로 attention head 로 나눠지지만 이 논문에서는 Wq (또는 Wk, Wv)를 차원 $d_{model} \times d_{model}$ 의 단일 행렬로 취급한다.
여기에서는 downstream task 에서 attention weight 에만 적용하고 MLP 모듈을 freeze 한다(즉, mlp 는 다운스트림 작업에서 훈련되지 않는다.)
Sec 7.1의 트랜스포머의 다양한 행렬 에 향후 작업에 적용하는 실험을 한다.
MLP layer, LayerNorm layer 와 bias 는 퓨처웍.
Practical Benefits and Limitation
가장 큰 이점은 메모리 및 storage 사용량 감소이다.
Adam과 함께 훈련된 대규모 Transformer의 경우 동결 파라미터에 대한 optimizer 상태를 저장할 필요가 없으므로 $r << d_{model}$ 일때 VRAM 사용량을 최대 2/3까지 줄일 수 있다.
GPT-3 175B에서는 training 중 VRAM 사용량을 1.2TB 에서 350GB 까지 줄일 수 있다.
6. Related Works
Low-Rank Structures in 딥러닝
머신러닝에서 low rank 구조는 매우 일반적이다. 많은 머신러닝 문제들이 특정한 intrinsic low-rank 구조를 가진다.
또한 많은 딥러닝 task 에서, 특히 뉴럴넷이 over-parameterized 된 상황에서 학습된 뉴럴넷은 training 이후 low-rank 특성을 이용하는 것으로 알려져 있다.(?? 아무래도 관련 논문 다 봐야할 듯…)
일부 이전 연구들은 오리지널 뉴럴넷을 학습할 때 low-rank constraint(조건) 을 명시적으로 부과하기도 한다.
하지만 (이 논문이 아는 한) 이러한 연구 중 아무도 downstream task 에서 weight를 고정시킨 모델에 대해 low-rank update 를 고려하지 않는다.
이하 논문 이외
구현적으로도 단순하다 : torch.nn.linear 를 상속받아서 새로운 linear를 만듦.
LoRA 를 좀 더 간단히 정리
: 기존의 파인튜닝은 (파란색 네모) pre-trained weight 를 계속 업데이트 하는 방식을 사용한다.
하지만 LoRA 는 기존 pre-trained weight 는 freeze 를 시켜놓고 새로 추가한 LoRA 모듈(decomposed matrix) 만 학습시켜 나중에 더해줘서 최종 모델을 산출하는 방식이다.
full finetuning 의 경우 d^2 의 학습 파라미터 수 → 2md (m<<d) 만큼 줄어든다.
Final form 은 $h \leftarrow h + s \cdot xW_{down}W_{up}$
위 식에서 s 는 scaling factor 이다.
** adapter 와 LoRA 의 차이
adapter 는 직렬연결이므로 위로 쌓고, 순차적 학습만 가능, LoRA 는 병렬연결로, 연산을 나눠서 병렬학습이 가능한 장점. 그래서 계산효율성이 더 좋음.
** fully fine tuning 식과 달라진 부분은
\Phi = pre-train 전체모델 파라미터
\Theta : 모듈 파라미터
변화량 \Delta 을 rank decomposition matrix 로 최적화한다.
'LLM 관련 논문 정리' 카테고리의 다른 글
Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning (NIPS, 2208) (0) | 2023.09.30 |
---|---|
PEFT (parameter-efficient fine tuning) 정리 (0) | 2023.09.18 |
InstructGPT 상세 리뷰 (0) | 2023.08.30 |
Rotary Position Embedding (RoPE) (0) | 2023.08.21 |
NTK-aware dynamic interpolation (0) | 2023.08.21 |