https://cameronrwolfe.substack.com/p/llama-2-from-the-ground-up
LLaMA-2 from the Ground Up
Everything you need to know about the best open-source LLM on the market...
cameronrwolfe.substack.com
블로그가 Llama-2 모델 설명이 매우 잘되있어서 정리한다.
LLaMA Model 구조
LLaMa-1 와의 차이점
LLAMA-1 에 비하면 더 많은 데이터(2 trillion tokens, 40% 더 늘어)로 pre-training 햇고 더 긴 context length (2k -> 4k, 4096) 를 가지고 더 빠른 inference 를 할 수 있도록 하는 구조(GQA) 를 가진다.
아래 그림은 LLaMa 모델과 기존 트랜스포머 디코더 모델의 구조적 차이를 잘 설명한 블로그 그림이다!
기존 decoder-only 모델 (gpt) 은 Attention is all you need 의 Transformer 에서의 encoder-decoder cross attention 모듈만을 뺀 구조인데, llama 는 아예 Normalization 을 Attn 과 FFN 앞에 뒀다.
각 요소를 하나하나 살펴보자.
RMS Normalization
그리고 Layer Normalization 이 아닌 RMS(Root Mean Square) Layer Norm 을 사용했다.
🪐 Normalizati Normalization 이란? : 각 feature 끼리 값의 분포 범위를 비슷하게 만드는 작업이다. Layer normalization 은 data sample 마다 가진 feature 개수로 나누어 평균과 분산을 구한다. batch size 가 작아도 효과적으로 사용할 수 있다.
SwiGLU
또한 LLAMA 모델은 일반적으로 사용했던 ReLU 함수가 아닌 SwiGLU activation function 을 FFN 에 사용했다.
위의 식과 같이, input x 를 weight W 와 V 로 element-wise product 한 후 Swish activation 을 적용한다.
3개의 matrix multiplication 이 필요로 하기 때문에 기존의 ReLU 와 같은 activation 보다는 더 computationally expansive 하다. 하지만 성능이 더 좋다고.
RoPE
Positional embedding 에 대해서도 일반 sign 과 cos 으로 계산된 Absolute 나 Relative Posituonal embedding 이 아닌 새로운 방법을 사용하였다.
RoPE (Rotary Positional Embedding) 가 그것이다.
회전 행렬을 사용하여 relative 정보를 주고 때문에 긴 context length 를 사용할 수 있게 되었다.
블로그 참조 : https://mari970.tistory.com/49
Rotary Position Embedding (RoPE)
Roformer: Enhanced Transformer with Rotary Positon Embedding https://arxiv.org/pdf/2104.09864.pdf 에서 처음 제안된 방법이다. Abstract 기존의 PE 는 트랜스포머에서 제안된 방법으로, 시퀀스 내의 토큰을 attention 만으로
mari970.tistory.com
Grouped Query Attention
GQA(Grouped Query Attention) 를 사용하여 inference 시간을 줄일 수 있었다
N개의 self attn head 를 m 개의 group 으로 나눈다.
가 그룹에 대해 m개의 key와 value 만을 공유해서 attention 을 계산한다.
(llama 모델이나 mistral 에 따라서 사용하기도 하고 안하기도 한다!)
Train 방법
위 까지는 구조적 차이이고, 이제 학습 방법론적 차이를 보자.
학습 방법은 모델마다 비슷하기 때문에(unlabeled data를 이용한 next token prediction objective 사용) 모델의 성능에 매우 영향을 미치는 것은 데이터의 양과 질이다. 특히 양보다 질이 중요하다는 연구 논문도 있었다. (무슨논문이더라.. 아마 “Lima: Less is more for alignment.”?)
Llama 와 Llama2 는 둘 다 pre-training 에 public data 를 사용했다. 이 두 모델의 차이점은 Llama2 에서는 여기에 더 좋은 퀄리티의 데이터셋을 추가하여 기존 데이터셋 양의 40% 늘렷다는 점이다.
Llama2 - chat 의 fine-tuning 방법
Supervised Fine Tuning(SFT) 를 한 후 PPO 를 이용한 RLHF 방법을 사용한다. 이는 Instructed GPT 에서 사용한 방법과 같다.
블로그 참조 : https://mari970.tistory.com/50
InstructGPT 상세 리뷰
Language Model 을 크게 만든다고 해서 user 의 의도를 더 잘 따르는 것은 아니다. LM의 안좋은 output 에는 1. untruthful 2. toxic 3. not helpful 이 있다. 이 논문에서는 human 피드백을 이용한 fine-tuning 을 통해 다
mari970.tistory.com
Alignment 란?
사용자의 목적에 부합하는 결과물을 생성하도록 학습시키는 방식이다.
Supervised Fine-Tuning
SFT는 다음 토큰 예측(Next Token Prediction) objective 를 사용하여 프롬프트-response 쌍 데이터셋을 학습한다.
LLaMA-2-Chat은 SFT 도 2단계를 , 첫번째 단계에서는 더 많은 public 데이터셋으로 학습하고, 2단계에서는 더 높은 품질을 데이터셋을 사용한다.
Reinforcement Learning from Human Feedback
🪐 RLHF : 사람 annotator 가 피드백한 (어떤 답변이 더 좋은 답변인지 ranking 해놓은) 데이터를 이용하여 reward 모델을 학습한다. 이 리워드 모델로 LM 모델이 생성한 답변에 대해 좋은 답변은 높은 리워드를 주고 나쁜 답변에는 낮은 리워드를 주어 학습시키는 강화학습의 방법.
인간 선호도의 요소로는 Helpness, Safety 등이 있다.
RLHF 를 위해서 Step 2 에서는 사람 annotator 가 만든 prompt 에 대해 (sft된) LLM 은 여러 개의 답변을 생성하고 이 답변에 대해서 다시 사람이 답변 퀄리티에 따른 ranking 을 매긴다.
Step 3 에서는 Step 2 에서 학습은 Reward 모델을 이용하여 PPO 라는 알고리즘을 통해 SFT 된 모델을 강화학습 시킨다.
** LLaMA 2 에서의 safety-based data 모으기.
를 위해서는 2가지 기술을 사용한다.
- Risk Categories: LLM 이 unsafe 한 내용을 생성할 수 있는 topic
- Attack Vectors: negative behavior 를 도출할 수 있는 다양한 프롬프트의 질문 스타일
1번에 대해서는 illicit and criminal activities, hateful and harmful activities, unqualified advice 3 가지 카테고리를 포함한다.
2번에 대해서는 psychological manipulation, logic manipulation 과 같은 attack 벡터들을 정의한다.
이러한 기준을 바탕으로 사람 annotator 들이 프롬프트를 만들어서 이런 안좋은 behavior 를 억제하도록 한다.
RLHF 의 과정은 아래에서 좀 더 자세히 알아보자.
Training the reward model.
리워드 모델은 LLM 과 같은 모델 아키텍처와 weight size 를 가지는 모델이 사용되지만 기존 모델의 NTP 을 위한 classification head 가 regression head 로 대체하여 preference 를 fine-tuning 시킨다. → regression 으로 값을 1개 생성한다.
아래 그림과 같다.
reward 모델은 prompt-response 쌍에 대해 preference score 를 구하여 강화학습에서 LLM 을 파인튜닝하는데에 사용한다. human preference 점수를 maximize 하도록 리워드모델은 자동으로 preference 점수를 생성한다.
리워드 모델을 학습시키기 위해서는 binary preference data 를 사용하고 prefered 데이터를 (counter part 데이터보다) 더 높은 점수를 얻도록 하는 로스 함수를 만든다.
objetive 식은 아래와 같다.
위 식에서 x 는 input seq 즉, 프롬프트 이고, y 는 그에 대한 responce 이다.
실제로 LLaMA 모델에서 사용하는 선호도 데이터셋은 단순한 binary 가 아니다. annotator 는 significantly better, better, slightly better, or negligibly better 로 더 세분화하여 라벨링한다.
그러므로 이런 더 detail 한 분류를 학습하기위해 loss 에 margin을 추가한다. 선호도 차이가 큰것에 대햐 더 큰 마진을 줌으로서 학습할 수 있다.
** Alignment 의 tradeoff
각 helpful 한 기능과 safety한 기능을 학습시키기 위해 서로 다른 리워드모델을 학습시킨다.
Helpful 과 safety 사이에는 트레이드 오프가있다. (어떤 alignment 기준에 대해서든 그렇다) 즉, 모델이 safety 하다면 helpful 하지 않을 수 있다.
→ 그래서 리워드 모델을 따로 사용한다!
Optimization via RL.
LLaMA-2 에서는 2가지 RLHF 알고리즘을 사용한다.
- PPO : 기본적인 RLHF 알고리즘
- Rejection Sampling : 각 프롬프트에 대한 k 개의 LLM response 를 샘플링하여 Reward 모델로 각 response 를 점수매기고 가장 best responce 를 뽑아서 이 예제를 가지고 fine-tuning 을 진행한다.
🪐 PPO(Proximal Policy Optimization) : Deeplearning 모델을 policy 로 하여 리워드를 이용해 업데이트하는 방식이다. Policy 의 과도한 업데이트를 막기위해 KL div 나 Clipping을 사용한다.
💡 Rejection Sampling Fine-Tuning : rejection sampling 이란 샘플을 쉬운 reference 분포에서 샘플을 생성한 후 복잡한 target 분포와 비교하여 적합하지 않은 샘플들은 빼서(rejection) 최종 샘플을 정하는 방법이다.
즉, LM 이 N 개의 답변을 생성햇을때 각 답변에 대한 리워드를 계산하여 특정 값 이상의 답변이나 제일 좋은 값의 답변만 채택한다. 이 답변을 label 로 모델을 파인튜닝 시키는 방법이다.
Ref : https://tech.scatterlab.co.kr/alt-rlhf/
PPO 는 1번 iteration 에 1개의 샘플만 사용하지만 RS 은 여러 개의 샘플을 필요로 한다. RS 는 생성한 답변 중 가장 높은 리워드를 가진 답변을 새로운 “gold standard” (즉, 타겟) 으로 간주한다.
이러한 방식으로 여러 답변을 생성하면 파인튜닝 동안 데이터에서 확인할 수 있는 최대 리워드값이 매우 증가한다.
LLaMA-2 에서 다른 작은 모델들을 RLHF 할 때 Rejection Sampling 에서 사용하는 모델은 LLaMA-70B-Chat 이다. 즉, 70B-chat 을 이용하여 데이터셋을 만들어낸다는 뜻이다.
PPO 는 각 샘플 이후에 update 를 하지만, Rejection Sampling Fine-tuning 은 RLHF 하기 전의 initial model 을 사용하여 높은 reward 샘플의 데이터셋을 생성한다. 이 데이터셋으로 파인튜닝(update)한다.
라마2는 rejection sampling 할 때 RLHF 의 모든 단계의 모델에서 Best sample 을 생성한다.
LLaMA-2 에서는 RS 파인튜닝 이후 PPO 를 적용한다.
** Tweaking the temperature
RLHF 에서 답변을 생성할 때 모델이 업데이트 될 때 마다 optimal temperature 가 변하는 것을 확인하였다.
그래서 RLHF 의 각 단계마다 temperature 를 다시 조정해야 한다.
그리고 또한 temperature 는 프롬프트 context 에 의존한다. 예를 들면, creative 나 factual 한 프롬프트와 같은 특정 타입의 프롬프트에서 에서 optimal temperature 변화를 확인하였다.
Ghost Attention for multi-turn dialogue
대화 agent 는 사용자가 Multi-turn 대화를 사용할 수 있도록 해야한다. 하지만 LLM 은 이전 대화를 잘 잊기 때문에 구현하는 것이 쉽지 않다. 그래서 초기 턴에만 적용될 수 있고 그랬다.
이를 해결하기 위해 Ghost Attention (GAtt) 를 사용하여 fine-tuning 하였다.
대화 세션이 주어지면, GAtt 는
- 대화에서 나온 instruction 을 샘플링하고
- 이 instruction 을 모든 사용자의 메세지 앞에 concat 시키고
- RLHF 된 모델을 이용해 각 메세지에 대한 responce 를 샘플링한다.
첫 번째 유저 메세지를 제외한 모든 대화에 붙은 inst 를 제거한 후
이 데이터로 multiturn 데이터셋을 만들 수 있다.
이 데이터셋을 이용해 SFT 학습시켜 멀티턴 학습을 시킬 수 있다.
LLaMA-2 (Base Model) Performance
open-source LLM 중에서는 가장 좋은 성능을 얻었다.
또한 오픈되지 않은 GPT-3.5 와 같은 모델과 비교할 때에는 성능이 떨어지지만 나름 comparable 하다.
Conclusion
- Garbage in = Garbage out : Dataset 의 품질이 매우 중요하다.
- Alignment 는 중요하다.
- RLHF 는 강력하다. : 매우 효과가 있다.
'LLM 관련 논문 정리' 카테고리의 다른 글
RAGAS: Automated Evaluation of Retrieval Augmented Generation (1) | 2024.11.10 |
---|---|
DoRA: Weight-Decomposed Low Rank Adaptation (0) | 2024.05.02 |
SOLAR model paper (1) | 2024.01.13 |
SPoT: Better Frozen Model Adaptation through Soft Prompt Transfer (0) | 2023.11.23 |
NEFTune: Noisy Embeddings Improve Instruction Fine-tuning (0) | 2023.11.15 |