Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum

2024. 11. 16. 22:48 · LLM 관련 논문 정리
목차
  1. Abstract
  2. Introduction
  3. Method
  4. 2.1 Dataset decomposition
  5. 2.2 Variable sequence length training
  6. 3. Experiments and analysis
  7. 3.1 Training efficiency
  8. 3.4 Length-based curriculum
  9. 4. Related works

** Apple 에서 나온 논문이다

https://arxiv.org/pdf/2405.13226

 

원래 decoder-only 모델을 사용할 때는 1개 문장을 학습할 때 max seq length 만큼 padding 하여 학습하는 방법을 사용한다.

하지만 padding 은 학습 비효율성을 매우매우 늘리기 때문에 특히나 LLM 을 학습할 때에는 불리하다.

그래서 요즘은 "packing" 이라는 방법을 사용한다.

packing 이란 padding 을 하지 않기 위해 다음 문장을 padding 부분에 붙여서 최대한 많은 문장을 학습하는 방법이다. 

 

이 논문은 이러한 packing 을 조금 더 개선한 방법이다.

 

Abstract

기존 LLM 이 고정된 seq seq length 를 이용하여 학습하기 위해 다양한 길이의 문장을 랜덤하게 concat 해서 chuncking 하는 방식을 사용한다.

이러한 방식은 seq 안에서 cross-document attention (=서로 다른 문서에 대해 1개 sequence 안에서 청킹으로 concat 되어서 attention 계산되는 것) 을 일으킬 수 있는데, 이는 바람직한 learning signal 이 아니고(?) 계산적으로 효율적이지도 않다.

또한 quadratic 한 계산 비용 증가로, 매우 비싸다.

 

그래서 이 연구에서는 다양한 길이의 seq len 을 학습하는 새로운 기술인 dataset decomposition 을 제안한다.

dataset 을 서로 다른 doc 에서 나온 같은 길이의 시퀀스로 이루어지는 bucket 의 조합으로 분해한다.

학습 동안 다양한 seq len 랑 batch size 를 사용하여 커리큘럼에 따라 모든 bucket 에서 동시에 샘플링한다.

 

(모든 학습 step 에서 고정된 attention cost 를 가지는) concat-and-chunk baseline 과 비교하여, 우리 방법은 각 step 에서 실제 document 길이에 비례한 penalty 를 발생시킨다. 그래서 학습에서 매우 큰 saving 을 얻는다.

 

concat-and-chunck 방법으로 학습한 2k context-length model 과 같은 계산비용으로 8k context-length 1B model 을 학습한다.

web-scale corpus 데이터셋 실험에 대해서는 이 논문에서 제안한 방식이 standard language evaluations 와 long-context benchmarks에서 baseline 에 비해 3배 빠르게 정확도를 얻었다.

긴 시퀀스에서 효과적으로 pre-train 할 뿐 아니라, dataset 크기에 대해 효과적으로 확장할 수 있다.

또한 sequence length 에 대한 분포와 커리큘럼 또한 조명한다.

 

Introduction

웹(아마 크롤링한 데이터셋이라는 뜻..) 데이터셋 : Pile [19], RefinedWeb [42], RedPajama [14], DOLMA [53] 등.

여러 documents 로 이루어져있다 : Wikipedia articles, books , code repositories.

documents 의 각각 길이는 짧은 메세지 부터 매우 긴 책까지 다양하다. 하지만 모델 학습은 1개 batch 당 한정적인 length 만을 지원한다. 효율적인 학습을 위해 chunking 을 많이 사용한다.

 

이 논문에서는 chuncking 의 효과를 확인하고 대안을 제시하고 평가한다. Recent works [39, 34, 55, 56] (=fairseq논문, Roberta 논문, Llama, Llama2 논문) 에서는 concat-and-chunk approach 를 많이 사용한다.

 

이 방법은 (학습전) data preparation 단계에서

우선 random shuffle 하고 tokenized documents 를 다 concatenate 한다 .

 

연속되는 doc 들은 special token <EOT>, 으로 분리되어 doc 간의 범위를 인식할 수 있도록 한다.

 

그 후 target sequence length 단위로 subsequences 를 나눈다(chunk).

 

예를 들어 각각 Llama-1 는 max seq len = 2048 와 Llama-2 는 4096을 가지는데, 이 길이만큼 모델이 pretrained 된다.

 

concat-and-chunk 방법의 단점

Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum

 

  1. 랜덤하게 doc 을 concat 하는 것은 모델이 서로 연관 없는 문서에 대해 attend 하여 다음 새로운 토큰을 예측하도록 할 수 있다. 잘 학습된 모델은 cross-document attention를 피하도록 학습할 수 있지만 (?어떻게?) 명백하게 적용되지는 않으므로 잠재적인 가짜 모델링이 필요하다.
  2. cross-document attention 는 서로 관계없는 토큰 간의 필요 없는 계산을 소비한다. 어텐션은 quadratic complexity 를 가지므로 이는 치명적이다.
  3. document 가 target sequence length보다 짧아도 2개 청크로 나뉠 수 있다. (2개 sequence 경계에 있을 때) 이는 평균 문서 길이에 비해 평균 청크 길이가 작아져서 모델 성능이 저하될 수 있다. (Fig. 3a)

그래서 LLM 학습 연구 들에서는 concat-and-chunk approach 를 발전시킨 방법들을 내고 있다:

 

- document-masking [36](llama 3) : cross-document attention 해결하기 위함

 

- best-fit packing [17](Fewer truncations improve language modeling) : document chunking를 줄이고 random 대신 관련있는 doc 를 concat 하기 위해 [51](In-context pretraining: Language modeling beyond document boundaries).

 

하지만 위 어떤 연구도 이 3가지 문제를 함께 언급하지는 않는다. (= 이 논문에서는 다 다루겠다.)

 

이 논문에서 주장하는 DD 는 concatand-chunk 보다 장점이 있다.

  1. DD 는 간단하고 data preparation 단계 동안 무시할 수 있는 계산 overhead 가 생긴다. 이를 통해 큰 dataset 에도 적용할 수 있다.
  2. 모든 bucket 의 각 시퀀스에 있는 토큰은 동일한 문서에서 가져온 것을 보장하기 때문에 cross document attention을 피할 수 있다.
  3. sequence length distribution 라는 보조 prior knowledge 에 접근하여 학습을 위해 다른 mixtures/curricula 를 사용할 수도 있다.
  4. VSL 학습을 이용하여 학습 시간을 단축할 수 있다. : attention 연산의 quadratic complexity 때문에 1개 최적화 step 의 latency 는 i 가 더 작은 D_i 에서 샘플링할 때 더 작다. (?)

contribution 은

  • DD 소개, DD 의 장점

 

  • 다양한 모델, dataset 을 이용해 eval 함 : 데이터 효율성(> 2배)과 계산 효율성(11%~45%)

 

  • 다른 언어와 long context 에 대한 학습(pre-training) 에서의 sequence length distribution 과 mixture 중요성을 실험으로 보인다. seq 길이를 합성적으로 변경하기 위한 concatenation and chunking 연산의 효과를 보인다.

 

Method

2.1 Dataset decomposition

D(dataset) = {d1, d2, . . . , dn} 이 때 d= tokenized document 가 주어지면

dataset decomposition (DD) 는 D 를 bucket 의 집합 ∪iDi∪_iD_i∪i​Di​ 으로 재생성한다. 이때

 

  1. 각 bucket D_i 는 l_i 길이의 token 의 seq 으로 이루어진다.
  2. s∈Dis ∈ D_is∈Di​ 인 각 시퀀스는 D 에 포함되는 1개 document 의 subsequence 이다.
    (s는 문장 단위가 아니다!)
  3. D 에 있는 각 token 은 정확히 1개의 D_i 안에 있다. (중복이 없다는 뜻인가,,)

(3) each token in D appears in exactly one Di

 

이 decomposition 은 각 unique document 에 들어있는 각 시퀀스 들을 만들어내(그러니까 같은 document 에 포함되는 문장들을 뭉친다는 것인가) 학습 동안 시퀀스 안에서 cross-docuent attention 이 없도록 한다.

 

게다가 주어진 bucket Di 안의 모든 시퀀스 들은 같은 길이 l_i 를 가지고 있어서 batching 하기 좋다.

 

그래서 이 논문에서는 구체적인 decomposition 을 제안한다.

 

이는 li=2il_i = 2^ili​=2i 를 사용하여 효율적인 batch pre-training 을 하면서 원래 오리지날의 문서 sequence length distribution을 최적으로 유지하기 위함이다. (2.2 에서 설명한다.)

 

decomposition 은 document 레벨로 적용하여 기존의 data preparation pipeline 과 잘 결합하여 사용할 수 있도록 한다. 또한 큰 dataset 에도 확장이 쉽다.

 

l=2i1+2i2+...+2ikl = 2^{i_1} + 2^{i_2} + . . . + 2^{i_k}l=2i1​+2i2​+...+2ik​ 는 길이 l 의 binary decomposition 을 진행한다.

 

tokenized document d 를 k개의 인접한 sequences s1,...,sks_1, . . . , s_ks1​,...,sk​ 로 나눈다. 이 때 길이는 각각 2i1,...,2ik2^{i_1}, . . . , 2^{i_k}2i1​,...,2ik​ 이다.

 

길이가 2ij2^{i_j}2ij​  인 sjs_jsj​ 시퀀스는 DijD_{i_j}Dij​​ 에 배정된다. (그림2)

 

Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum

 

각 bucket D_i 는 오리지널 doc d 의 길이가 최소 2i2^i2i 가 되도록 하는 d 에서 추출한 시퀀스들을 포함한다. (오리지널 doc의 길이를 기준, 자를 길이보다 큰 오리지널 문서를 사용한다는 뜻인듯하다.)

 

Fig. 3b 는 RefinedWeb dataset tokens의 분포를 각 bucket 마다 보여준다.

 

D_9 는 길이 512 가 max length 를 가진다.

 

이 논문에서는 또한 토큰을 추출하는 오리지널 doc 의 길이 강조한다.

 

bucket D_i 의 대부분의 토큰들은 doc 길이 l 이 2i≤l<2i+12^i ≤ l < 2^{i+1}2i≤l<2i+1 이 되도록 하는 doc 에서 추출하고, 나머지는 l≥2i+1l ≥ 2^{i+1}l≥2i+1 인 doc 로 이월된다. (=roll over)

 

이는 많이 없는 긴 문서에서 특히 오리지널 문서의 길이를 유지하는 효율을 보여준다.

 

Fig. 3a 에서 오리지널 doc 과 2048, 8192 seq_len 로 concat-and-chunk 된 chunk 의 분포를 나타낸다.

이 논문에서는 또한 최근 연구인 bin-packing approximate 알고리즘의 길이 분포에 대해서도 나타낸다.

 

또한 Fig. 3c 에서는 DD 를 사용할 때와 seq_len 8192 baseline 에서의 context length (pre-training 동안 1개 토큰이 (같은 문서 안에서!) attend 할 수 있는 토큰 개수)의 분포 를 확인할 수 있다. (Appendix F)

 

정적인 dataset 을 출력하는 concat-and-chunk 방법과 다르게 DD 는 prior knowledge 에 따라 seq len 분포를 사용할 수 있고, target task 에 따라 best mixture 를 최적화할 수 있다.

 

2.2 Variable sequence length training

2i2^i2i 길이를 가지는 시퀀스를 포함하는 D_i bucket 이 K 개 있다고 가정한다.

 

b = target batch size (optimization step 마다 사용되는 token 개수)

 

variable sequence length (VSL) training 은 모든 최적화 step 에서 가능한 범위 내의 i 값을 먼저 샘플링하고 그 후 D_i 에서 b/2ib/2^ib/2i 개의 시퀀스를 고른다.

 

D_i 는 2^i 길이의 seq 로 이루어져 있기 때문에 최적화 step 마다 보이는 token 개수는 b 이고 i 의 선택에 독립적이다.

 

⇒ 즉, 가변적인 seq len 을 사용한다는 것은 가변적인 batch size 를 사용한다는 것이다. 이거 한마디면 모든 내용이 이해가 간다!

 

 

LLM 을 vsl 로 학습시키는 것의 장점

  1. 최적화 step 마다 보이는 token 개수 는 변하지 않기 때문에 최적화 방법을 바꿀 필요 없고 baseline 과 같은 hyper-parameters 를 사용할 수 있다.
  2. Section 3.1 처럼 , 고정된 b 사이즈에 대한 (forward+backward)의 1개 최적화 step 을 완료하는 시간이 attention 연산(quadratic cost )때문에 시퀀스 길이에 따라 변한다는 것을 보인다.
  3. VSL training 에서는 모든 최적화 step 에서의 비용은 해당 step 의 sequence length 에 대해 샘플링된 bucket D_i 에 따라 달라진다. 따라서 더 긴 시퀀스의 step 은 더 짧은 시퀀스 step에서 보상받을 수 있다.
  4. VSL 의 sampling component (=모든 최적화 step 에서 어떤 D_i 가 선택될지) 는 서로 다른 sequnce length 의 커리큘럼이 가능하다. Section 3.4에서는 그런 커리큘럼의 중요성과 그에 따라 모델 안정성과 일반화 정확도가 어떻게 달라지는 지를 보인다.

** 그러면 커리큘럼이 무엇인가

 

<aside> 💡 “커리큘럼”이란 뭔가 : Yoshua Bengio의 Curriculum learning 논문.

학습 과정동안 의미있는 순서대로(더 복잡한 순서로) 학습을 시키면 더 효과적이라는 것. 그래서 쉬운 데이터로 학습을 시작하여 점점 전체 데이터를 사용하는 방법이다. → 즉, 학습의 순서에 대한 이야기.

</aside>

 

3. Experiments and analysis

Section 3.5 를 제외한 모든 실험에서 EleutherAI/gpt neox [9] tokenizer (vocabulary size is 50,432) 사용.

 

모델과 학습코드는 OpenLM 사용. (OpenLM-1B 8k 모델 사용)

 

RoPE 사용. f=10,000

[44, 60, 33](Yarn, Effective long-context scaling of foundation models, Scaling laws of rope-based extrapolation) 에서는 frequency 를 늘리고 fine-tuning 하면 좋다는 결과가 있는데, 이 실험에서는 pre-train 부터 늘려서 사용했을 때도 좋다는 것을 확인했다. f=100,000 (Table 4)

 

Evaluation (벤치마크) : LLM Foundry

Task

Commonsense Reasoning (CSR) PIQA-0-shot, COPA-0-shot , OpenBookQA-10-shots
Language Understanding (LU) Lambada-OpenAI , Hellaswag-0-shot , Winograd-3-shots , WinoGrande-5-shots
Reading Comprehension (RC) SQuAD-3-shots [46], BoolQ-0-shot [12], and CoQA-0-shot
World Knowledge (WK) Jeopardy-3-shots [3], ArcEasy-3-shots [13], ArcChallenge-3-
shots [13], and WikiDataQA-3-shots  
Multi-Document Question Answering (MDQA) MDQA-10, MDQA-20, and MDQA-30
TOEFL -
QuALITY -

 

** 아래 흰색은 longer context 를 위한 eval set

 

3.1 Training efficiency

VSL 학습이 concat-and-chunk 보ㄷ 더 높은 throughput 을 가지는지 확인한다.

 

모델 사이즈 (OpenLM-1B/3B/7B)와 context length (2^6 to 2^13) 에 따른 100 bach (size = 8 × 8192)학습에서의 학습 시간을 따진다. : 8 GPUs single node (Appendix C.1)

 

3.4 Length-based curriculum

보통 더 짧은 시퀀스가 긴 시퀀스보다 쉽다 고 생각할 수 있다.

 

VSL 에서도 샘플링 디자인을 통해 curriculum learning 을 쉽게 구현할 수 있다.

 

매 최적화 step 에서 bucket D_i 에서 확률 p_i 로 샘플링하되, b 개 토큰으로 batch 를 대체하지는 않는다. (??뭔소리)

 

bucket 이 비어있다면 샘플링에서 제외한다.

 

이 논문에서는 같은 개수의 토큰이 들어있는 D8, . . . , D13 bucket 에서 "≥ 256" mixture 에 대한 커리큘럼을 연구한다.

각 커리큘럼에서 각 bucket 에서 batch 를 샘플링할 확률(normalized 된 p_i)을 결정한다.

 

커리큘럼은 short 부터 long 시퀀스까지 변화하는 pace 인데, 이 때 p_i 는 linear 하거나, 2의 거듭제곱으로변화하거나, bucket 사이에서는 100 의 거듭제곱으로 변화하는 것 3가지를 고려한다.

 

커리큘럼은 implicit bias 를 일으킬수도 있다. 예를 들면 학습 끝에 긴 시퀀스만 보게되면 learning rate 가 작을 때만 긴 시퀀스 학습이 발생할 수 있다.

 

이런 문제를 해결하기위해 커리큘럼이 주기적으로 적용되는 cyclic curricula 도 확인한다.

 

Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum

 

"Grow-P2"(아마 2의 거듭제곱 방법) curriculum이 가장 최적임을 보인다.

 

추가적인 장점은 훈련 안정성이다.

[30]논문(The stability-efficiency dilemma: Investigating sequence length warmup for training gpt models.) 은 긴 시퀀스가 (특히) 훈련 초기에는 극단적인 gradient 변화를 만들어 불안정성을 초래한다고 한다.

 

또한 이 논문에서 제안한 커리큘럼방법이 (뭘말하는 거지, Grow P2?) 더 안정적인 학습을 얻을 수 있도 그래서 더 큰 batch size 와 lr 에 대해 더 효율적인 학습을 할 수 있다. (Appendix E)

 

4. Related works

최근 연구는 cross-document attention 에 대한 우려가 많다. 예를 들면 Llama-3 등 (section 3.6)

 

25 논문에서는 concat-and-chunk 의 문제점과 이의 대안으로 approximate bin-packing algorithm 을 제안한다.

(25: Efficient sequence packing without cross-contamination: Accelerating large language models without impacting performance)

 

sequence length bias 에 대한 연구(58)는 시퀀스 길이 관점에서 train-vs-test의 time distribution shift의 중요성을 말한다.

 

[6, 62, 24, 33] 는 학습때 모델이 본 길이 이상으로 일반화하는 것에 대해 강조하고 positional encoding의 중요성을 말한다.

 

[38, 63, 23, 60, 10, 43, 44, 49]는 long context 의 inference 를 가능하게 하는 연구를 한다.

 

이러한 접근법은 우리의 contribution 과 orthogonal 하고(? 아마 같이 사용할 수 있다는 뜻) pre-training 이후에도 적용가능하다.

 

dynamic batching 는 다른 도메인 연구에서 많이 나왔다. (비전, seq-to-seq tasks, neural machine translation 등)

효율적인 방법은 길이로 인풋을 sorting 하여 훈련시에는 비슷한 길이의 input 을 batch 로 형성한다. (padding 이후에(?)) batch size 는 input length 에 대해 동적으로 변화한다.

 

이런 연구들과 다르게 dataset decomposition은 단순히 여러 docs 를 비슷한 길이로 같은 bucket 에 넣지 않는다. 대신에 각 doc 을 여러 subsequence 들로 나누고 여러 bucket 을 생성한다.

 

이 bucket 에서 샘플링하여 다른 길이의 batch 를 생성한다.

'LLM 관련 논문 정리' 카테고리의 다른 글

RAGAS 라이브러리 평가지표 설명  (0) 2025.05.10
Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models (EEVE) 논문 정리  (0) 2025.05.05
RAGAS: Automated Evaluation of Retrieval Augmented Generation  (1) 2024.11.10
DoRA: Weight-Decomposed Low Rank Adaptation  (0) 2024.05.02
LLAMA-2 from the ground up  (0) 2024.02.11
  1. Abstract
  2. Introduction
  3. Method
  4. 2.1 Dataset decomposition
  5. 2.2 Variable sequence length training
  6. 3. Experiments and analysis
  7. 3.1 Training efficiency
  8. 3.4 Length-based curriculum
  9. 4. Related works
'LLM 관련 논문 정리' 카테고리의 다른 글
  • RAGAS 라이브러리 평가지표 설명
  • Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models (EEVE) 논문 정리
  • RAGAS: Automated Evaluation of Retrieval Augmented Generation
  • DoRA: Weight-Decomposed Low Rank Adaptation
섬섬옥수수
섬섬옥수수
컴공 AI 개발자가 되기 위한 노역입니다
섬섬옥수수
아날로그 인간의 컴공 되기
섬섬옥수수
전체
오늘
어제
  • 분류 전체보기
    • 백준 단계별 코딩 테스트
    • KB 논문 정리
    • Memory network 논문 정리
    • LLM 관련 논문 정리
    • Python 및 Torch 코딩 이모저모
    • Clustering 관련 논문 정리
    • 머신러닝 이모저모
    • 암호학

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

인기 글

태그

  • constituency tree
  • 코딩테스트
  • dependency tree
  • e5-v
  • 티스토리챌린지
  • CUDA
  • ragas
  • 문제풀이
  • vocabulary expansion
  • PEFT
  • 소프트웨어
  • eeve
  • 이화여대
  • GIT
  • 오블완
  • 인공지능융합기반시스템개론
  • 하드웨어
  • 심재형
  • efficient and effective vocabulary expansion towards multilingual large language models
  • 백준

최근 댓글

최근 글

hELLO · Designed By 정상우.v4.2.0
섬섬옥수수
Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum
상단으로

티스토리툴바

개인정보

  • 티스토리 홈
  • 포럼
  • 로그인

단축키

내 블로그

내 블로그 - 관리자 홈 전환
Q
Q
새 글 쓰기
W
W

블로그 게시글

글 수정 (권한 있는 경우)
E
E
댓글 영역으로 이동
C
C

모든 영역

이 페이지의 URL 복사
S
S
맨 위로 이동
T
T
티스토리 홈 이동
H
H
단축키 안내
Shift + /
⇧ + /

* 단축키는 한글/영문 대소문자로 이용 가능하며, 티스토리 기본 도메인에서만 동작합니다.