** Apple 에서 나온 논문이다
https://arxiv.org/pdf/2405.13226
원래 decoder-only 모델을 사용할 때는 1개 문장을 학습할 때 max seq length 만큼 padding 하여 학습하는 방법을 사용한다.
하지만 padding 은 학습 비효율성을 매우매우 늘리기 때문에 특히나 LLM 을 학습할 때에는 불리하다.
그래서 요즘은 "packing" 이라는 방법을 사용한다.
packing 이란 padding 을 하지 않기 위해 다음 문장을 padding 부분에 붙여서 최대한 많은 문장을 학습하는 방법이다.
이 논문은 이러한 packing 을 조금 더 개선한 방법이다.
Abstract
기존 LLM 이 고정된 seq seq length 를 이용하여 학습하기 위해 다양한 길이의 문장을 랜덤하게 concat 해서 chuncking 하는 방식을 사용한다.
이러한 방식은 seq 안에서 cross-document attention (=서로 다른 문서에 대해 1개 sequence 안에서 청킹으로 concat 되어서 attention 계산되는 것) 을 일으킬 수 있는데, 이는 바람직한 learning signal 이 아니고(?) 계산적으로 효율적이지도 않다.
또한 quadratic 한 계산 비용 증가로, 매우 비싸다.
그래서 이 연구에서는 다양한 길이의 seq len 을 학습하는 새로운 기술인 dataset decomposition 을 제안한다.
dataset 을 서로 다른 doc 에서 나온 같은 길이의 시퀀스로 이루어지는 bucket 의 조합으로 분해한다.
학습 동안 다양한 seq len 랑 batch size 를 사용하여 커리큘럼에 따라 모든 bucket 에서 동시에 샘플링한다.
(모든 학습 step 에서 고정된 attention cost 를 가지는) concat-and-chunk baseline 과 비교하여, 우리 방법은 각 step 에서 실제 document 길이에 비례한 penalty 를 발생시킨다. 그래서 학습에서 매우 큰 saving 을 얻는다.
concat-and-chunck 방법으로 학습한 2k context-length model 과 같은 계산비용으로 8k context-length 1B model 을 학습한다.
web-scale corpus 데이터셋 실험에 대해서는 이 논문에서 제안한 방식이 standard language evaluations 와 long-context benchmarks에서 baseline 에 비해 3배 빠르게 정확도를 얻었다.
긴 시퀀스에서 효과적으로 pre-train 할 뿐 아니라, dataset 크기에 대해 효과적으로 확장할 수 있다.
또한 sequence length 에 대한 분포와 커리큘럼 또한 조명한다.
Introduction
웹(아마 크롤링한 데이터셋이라는 뜻..) 데이터셋 : Pile [19], RefinedWeb [42], RedPajama [14], DOLMA [53] 등.
여러 documents 로 이루어져있다 : Wikipedia articles, books , code repositories.
documents 의 각각 길이는 짧은 메세지 부터 매우 긴 책까지 다양하다. 하지만 모델 학습은 1개 batch 당 한정적인 length 만을 지원한다. 효율적인 학습을 위해 chunking 을 많이 사용한다.
이 논문에서는 chuncking 의 효과를 확인하고 대안을 제시하고 평가한다. Recent works [39, 34, 55, 56] (=fairseq논문, Roberta 논문, Llama, Llama2 논문) 에서는 concat-and-chunk approach 를 많이 사용한다.
이 방법은 (학습전) data preparation 단계에서
우선 random shuffle 하고 tokenized documents 를 다 concatenate 한다 .
연속되는 doc 들은 special token <EOT>, 으로 분리되어 doc 간의 범위를 인식할 수 있도록 한다.
그 후 target sequence length 단위로 subsequences 를 나눈다(chunk).
예를 들어 각각 Llama-1 는 max seq len = 2048 와 Llama-2 는 4096을 가지는데, 이 길이만큼 모델이 pretrained 된다.
concat-and-chunk 방법의 단점
- 랜덤하게 doc 을 concat 하는 것은 모델이 서로 연관 없는 문서에 대해 attend 하여 다음 새로운 토큰을 예측하도록 할 수 있다. 잘 학습된 모델은 cross-document attention를 피하도록 학습할 수 있지만 (?어떻게?) 명백하게 적용되지는 않으므로 잠재적인 가짜 모델링이 필요하다.
- cross-document attention 는 서로 관계없는 토큰 간의 필요 없는 계산을 소비한다. 어텐션은 quadratic complexity 를 가지므로 이는 치명적이다.
- document 가 target sequence length보다 짧아도 2개 청크로 나뉠 수 있다. (2개 sequence 경계에 있을 때) 이는 평균 문서 길이에 비해 평균 청크 길이가 작아져서 모델 성능이 저하될 수 있다. (Fig. 3a)
그래서 LLM 학습 연구 들에서는 concat-and-chunk approach 를 발전시킨 방법들을 내고 있다:
- document-masking [36](llama 3) : cross-document attention 해결하기 위함
- best-fit packing [17](Fewer truncations improve language modeling) : document chunking를 줄이고 random 대신 관련있는 doc 를 concat 하기 위해 [51](In-context pretraining: Language modeling beyond document boundaries).
하지만 위 어떤 연구도 이 3가지 문제를 함께 언급하지는 않는다. (= 이 논문에서는 다 다루겠다.)
이 논문에서 주장하는 DD 는 concatand-chunk 보다 장점이 있다.
- DD 는 간단하고 data preparation 단계 동안 무시할 수 있는 계산 overhead 가 생긴다. 이를 통해 큰 dataset 에도 적용할 수 있다.
- 모든 bucket 의 각 시퀀스에 있는 토큰은 동일한 문서에서 가져온 것을 보장하기 때문에 cross document attention을 피할 수 있다.
- sequence length distribution 라는 보조 prior knowledge 에 접근하여 학습을 위해 다른 mixtures/curricula 를 사용할 수도 있다.
- VSL 학습을 이용하여 학습 시간을 단축할 수 있다. : attention 연산의 quadratic complexity 때문에 1개 최적화 step 의 latency 는 i 가 더 작은 D_i 에서 샘플링할 때 더 작다. (?)
contribution 은
- DD 소개, DD 의 장점
- 다양한 모델, dataset 을 이용해 eval 함 : 데이터 효율성(> 2배)과 계산 효율성(11%~45%)
- 다른 언어와 long context 에 대한 학습(pre-training) 에서의 sequence length distribution 과 mixture 중요성을 실험으로 보인다. seq 길이를 합성적으로 변경하기 위한 concatenation and chunking 연산의 효과를 보인다.
Method
2.1 Dataset decomposition
D(dataset) = {d1, d2, . . . , dn} 이 때 d= tokenized document 가 주어지면
dataset decomposition (DD) 는 D 를 bucket 의 집합 $∪_iD_i$ 으로 재생성한다. 이때
- 각 bucket D_i 는 l_i 길이의 token 의 seq 으로 이루어진다.
- $s ∈ D_i$ 인 각 시퀀스는 D 에 포함되는 1개 document 의 subsequence 이다.
(s는 문장 단위가 아니다!) - D 에 있는 각 token 은 정확히 1개의 D_i 안에 있다. (중복이 없다는 뜻인가,,)
(3) each token in D appears in exactly one Di
이 decomposition 은 각 unique document 에 들어있는 각 시퀀스 들을 만들어내(그러니까 같은 document 에 포함되는 문장들을 뭉친다는 것인가) 학습 동안 시퀀스 안에서 cross-docuent attention 이 없도록 한다.
게다가 주어진 bucket Di 안의 모든 시퀀스 들은 같은 길이 l_i 를 가지고 있어서 batching 하기 좋다.
그래서 이 논문에서는 구체적인 decomposition 을 제안한다.
이는 $l_i = 2^i$ 를 사용하여 효율적인 batch pre-training 을 하면서 원래 오리지날의 문서 sequence length distribution을 최적으로 유지하기 위함이다. (2.2 에서 설명한다.)
decomposition 은 document 레벨로 적용하여 기존의 data preparation pipeline 과 잘 결합하여 사용할 수 있도록 한다. 또한 큰 dataset 에도 확장이 쉽다.
$l = 2^{i_1} + 2^{i_2} + . . . + 2^{i_k}$ 는 길이 l 의 binary decomposition 을 진행한다.
tokenized document d 를 k개의 인접한 sequences $s_1, . . . , s_k$ 로 나눈다. 이 때 길이는 각각 $2^{i_1}, . . . , 2^{i_k}$ 이다.
길이가 $2^{i_j}$ 인 $s_j$ 시퀀스는 $D_{i_j}$ 에 배정된다. (그림2)
각 bucket D_i 는 오리지널 doc d 의 길이가 최소 $2^i$ 가 되도록 하는 d 에서 추출한 시퀀스들을 포함한다. (오리지널 doc의 길이를 기준, 자를 길이보다 큰 오리지널 문서를 사용한다는 뜻인듯하다.)
Fig. 3b 는 RefinedWeb dataset tokens의 분포를 각 bucket 마다 보여준다.
D_9 는 길이 512 가 max length 를 가진다.
이 논문에서는 또한 토큰을 추출하는 오리지널 doc 의 길이 강조한다.
bucket D_i 의 대부분의 토큰들은 doc 길이 l 이 $2^i ≤ l < 2^{i+1}$ 이 되도록 하는 doc 에서 추출하고, 나머지는 $l ≥ 2^{i+1}$ 인 doc 로 이월된다. (=roll over)
이는 많이 없는 긴 문서에서 특히 오리지널 문서의 길이를 유지하는 효율을 보여준다.
Fig. 3a 에서 오리지널 doc 과 2048, 8192 seq_len 로 concat-and-chunk 된 chunk 의 분포를 나타낸다.
이 논문에서는 또한 최근 연구인 bin-packing approximate 알고리즘의 길이 분포에 대해서도 나타낸다.
또한 Fig. 3c 에서는 DD 를 사용할 때와 seq_len 8192 baseline 에서의 context length (pre-training 동안 1개 토큰이 (같은 문서 안에서!) attend 할 수 있는 토큰 개수)의 분포 를 확인할 수 있다. (Appendix F)
정적인 dataset 을 출력하는 concat-and-chunk 방법과 다르게 DD 는 prior knowledge 에 따라 seq len 분포를 사용할 수 있고, target task 에 따라 best mixture 를 최적화할 수 있다.
2.2 Variable sequence length training
$2^i$ 길이를 가지는 시퀀스를 포함하는 D_i bucket 이 K 개 있다고 가정한다.
b = target batch size (optimization step 마다 사용되는 token 개수)
variable sequence length (VSL) training 은 모든 최적화 step 에서 가능한 범위 내의 i 값을 먼저 샘플링하고 그 후 D_i 에서 $b/2^i$ 개의 시퀀스를 고른다.
D_i 는 2^i 길이의 seq 로 이루어져 있기 때문에 최적화 step 마다 보이는 token 개수는 b 이고 i 의 선택에 독립적이다.
⇒ 즉, 가변적인 seq len 을 사용한다는 것은 가변적인 batch size 를 사용한다는 것이다. 이거 한마디면 모든 내용이 이해가 간다!
LLM 을 vsl 로 학습시키는 것의 장점
- 최적화 step 마다 보이는 token 개수 는 변하지 않기 때문에 최적화 방법을 바꿀 필요 없고 baseline 과 같은 hyper-parameters 를 사용할 수 있다.
- Section 3.1 처럼 , 고정된 b 사이즈에 대한 (forward+backward)의 1개 최적화 step 을 완료하는 시간이 attention 연산(quadratic cost )때문에 시퀀스 길이에 따라 변한다는 것을 보인다.
- VSL training 에서는 모든 최적화 step 에서의 비용은 해당 step 의 sequence length 에 대해 샘플링된 bucket D_i 에 따라 달라진다. 따라서 더 긴 시퀀스의 step 은 더 짧은 시퀀스 step에서 보상받을 수 있다.
- VSL 의 sampling component (=모든 최적화 step 에서 어떤 D_i 가 선택될지) 는 서로 다른 sequnce length 의 커리큘럼이 가능하다. Section 3.4에서는 그런 커리큘럼의 중요성과 그에 따라 모델 안정성과 일반화 정확도가 어떻게 달라지는 지를 보인다.
** 그러면 커리큘럼이 무엇인가
<aside> 💡 “커리큘럼”이란 뭔가 : Yoshua Bengio의 Curriculum learning 논문.
학습 과정동안 의미있는 순서대로(더 복잡한 순서로) 학습을 시키면 더 효과적이라는 것. 그래서 쉬운 데이터로 학습을 시작하여 점점 전체 데이터를 사용하는 방법이다. → 즉, 학습의 순서에 대한 이야기.
</aside>
3. Experiments and analysis
Section 3.5 를 제외한 모든 실험에서 EleutherAI/gpt neox [9] tokenizer (vocabulary size is 50,432) 사용.
모델과 학습코드는 OpenLM 사용. (OpenLM-1B 8k 모델 사용)
RoPE 사용. f=10,000
[44, 60, 33](Yarn, Effective long-context scaling of foundation models, Scaling laws of rope-based extrapolation) 에서는 frequency 를 늘리고 fine-tuning 하면 좋다는 결과가 있는데, 이 실험에서는 pre-train 부터 늘려서 사용했을 때도 좋다는 것을 확인했다. f=100,000 (Table 4)
Evaluation (벤치마크) : LLM Foundry
Task
Commonsense Reasoning (CSR) | PIQA-0-shot, COPA-0-shot , OpenBookQA-10-shots |
Language Understanding (LU) | Lambada-OpenAI , Hellaswag-0-shot , Winograd-3-shots , WinoGrande-5-shots |
Reading Comprehension (RC) | SQuAD-3-shots [46], BoolQ-0-shot [12], and CoQA-0-shot |
World Knowledge (WK) | Jeopardy-3-shots [3], ArcEasy-3-shots [13], ArcChallenge-3- |
shots [13], and WikiDataQA-3-shots | |
Multi-Document Question Answering (MDQA) | MDQA-10, MDQA-20, and MDQA-30 |
TOEFL | - |
QuALITY | - |
** 아래 흰색은 longer context 를 위한 eval set
3.1 Training efficiency
VSL 학습이 concat-and-chunk 보ㄷ 더 높은 throughput 을 가지는지 확인한다.
모델 사이즈 (OpenLM-1B/3B/7B)와 context length (2^6 to 2^13) 에 따른 100 bach (size = 8 × 8192)학습에서의 학습 시간을 따진다. : 8 GPUs single node (Appendix C.1)
3.4 Length-based curriculum
보통 더 짧은 시퀀스가 긴 시퀀스보다 쉽다 고 생각할 수 있다.
VSL 에서도 샘플링 디자인을 통해 curriculum learning 을 쉽게 구현할 수 있다.
매 최적화 step 에서 bucket D_i 에서 확률 p_i 로 샘플링하되, b 개 토큰으로 batch 를 대체하지는 않는다. (??뭔소리)
bucket 이 비어있다면 샘플링에서 제외한다.
이 논문에서는 같은 개수의 토큰이 들어있는 D8, . . . , D13 bucket 에서 "≥ 256" mixture 에 대한 커리큘럼을 연구한다.
각 커리큘럼에서 각 bucket 에서 batch 를 샘플링할 확률(normalized 된 p_i)을 결정한다.
커리큘럼은 short 부터 long 시퀀스까지 변화하는 pace 인데, 이 때 p_i 는 linear 하거나, 2의 거듭제곱으로변화하거나, bucket 사이에서는 100 의 거듭제곱으로 변화하는 것 3가지를 고려한다.
커리큘럼은 implicit bias 를 일으킬수도 있다. 예를 들면 학습 끝에 긴 시퀀스만 보게되면 learning rate 가 작을 때만 긴 시퀀스 학습이 발생할 수 있다.
이런 문제를 해결하기위해 커리큘럼이 주기적으로 적용되는 cyclic curricula 도 확인한다.
"Grow-P2"(아마 2의 거듭제곱 방법) curriculum이 가장 최적임을 보인다.
추가적인 장점은 훈련 안정성이다.
[30]논문(The stability-efficiency dilemma: Investigating sequence length warmup for training gpt models.) 은 긴 시퀀스가 (특히) 훈련 초기에는 극단적인 gradient 변화를 만들어 불안정성을 초래한다고 한다.
또한 이 논문에서 제안한 커리큘럼방법이 (뭘말하는 거지, Grow P2?) 더 안정적인 학습을 얻을 수 있도 그래서 더 큰 batch size 와 lr 에 대해 더 효율적인 학습을 할 수 있다. (Appendix E)
4. Related works
최근 연구는 cross-document attention 에 대한 우려가 많다. 예를 들면 Llama-3 등 (section 3.6)
25 논문에서는 concat-and-chunk 의 문제점과 이의 대안으로 approximate bin-packing algorithm 을 제안한다.
(25: Efficient sequence packing without cross-contamination: Accelerating large language models without impacting performance)
sequence length bias 에 대한 연구(58)는 시퀀스 길이 관점에서 train-vs-test의 time distribution shift의 중요성을 말한다.
[6, 62, 24, 33] 는 학습때 모델이 본 길이 이상으로 일반화하는 것에 대해 강조하고 positional encoding의 중요성을 말한다.
[38, 63, 23, 60, 10, 43, 44, 49]는 long context 의 inference 를 가능하게 하는 연구를 한다.
이러한 접근법은 우리의 contribution 과 orthogonal 하고(? 아마 같이 사용할 수 있다는 뜻) pre-training 이후에도 적용가능하다.
dynamic batching 는 다른 도메인 연구에서 많이 나왔다. (비전, seq-to-seq tasks, neural machine translation 등)
효율적인 방법은 길이로 인풋을 sorting 하여 훈련시에는 비슷한 길이의 input 을 batch 로 형성한다. (padding 이후에(?)) batch size 는 input length 에 대해 동적으로 변화한다.
이런 연구들과 다르게 dataset decomposition은 단순히 여러 docs 를 비슷한 길이로 같은 bucket 에 넣지 않는다. 대신에 각 doc 을 여러 subsequence 들로 나누고 여러 bucket 을 생성한다.
이 bucket 에서 샘플링하여 다른 길이의 batch 를 생성한다.
'LLM 관련 논문 정리' 카테고리의 다른 글
RAGAS: Automated Evaluation of Retrieval Augmented Generation (1) | 2024.11.10 |
---|---|
DoRA: Weight-Decomposed Low Rank Adaptation (0) | 2024.05.02 |
LLAMA-2 from the ground up (0) | 2024.02.11 |
SOLAR model paper (1) | 2024.01.13 |
SPoT: Better Frozen Model Adaptation through Soft Prompt Transfer (0) | 2023.11.23 |