NTK-aware dynamic interpolation

2023. 8. 21. 11:36 · LLM 관련 논문 정리
목차
  1. 개념 간단 정리
  2. 구현 설명
  3. 결과

기존 논문에서 나온 RoPE interpolation (혹은 extrapolation) 과 다른 방법으로, 기존 방법이 linear 방법이라면, 지금 소개하는 방법은 dynamic interpolation 이다.

Qwen-7B 등 다양한 LLM 모델에서 적용되고 있고, Huggingface 에서도 구현해놓았다.

기본적으로 RoPE 로 학습된 모델만 있다면 evaluation 에서 적용하는 것이 어렵지 않기 때문에 많이 이용한다.

 

Reddit 에서 처음 소개된 방법이다.

https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

 

From the LocalLLaMA community on Reddit: NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without a

Explore this post and more from the LocalLLaMA community

www.reddit.com

위링크 에서 스레드로 언급되었지만 fine tuning 를 따로 하지 않아도 좋은 성능을 얻을 수 있다고 한.

 

NTK 는 Neural Tangent Kernel 인데, 개념이 조금 어렵기 때문에 이 블로그에서는 설명하지 않으려 한다.

참조 : arXiv:1806.07572

 

개념 간단 정리

하지만 구현 방법은 매우 쉽다!

위의 식은 Transformer 의 positional encoding 이다.

base=base∗αdim/(dim−2) base= base * \alpha ^{dim/(dim-2)} base=base∗αdim/(dim−2)

위에서 base는 10000을 가리킨다.

그리고 alpha 는 NTK-Aware hyper parameter를 말하고, “scale” 이라고 한다.

 

α=(α∗current_sequence_length/original_model_context_length)−(α−1) α = (α * current\_sequence\_length / original\_model\_context\_length) - (α - 1) α=(α∗current_sequence_length/original_model_context_length)−(α−1)

위에서 original model context len 은 모델이 학습할 때 사용한 max seq len (요즘 LLM 은 약 2048) 을 말한다.

current seq len 은 인풋으로 들어가는 실제 문장의 seq length 를 말한다.

 

구현 설명

https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py

 

modeling_qwen.py · Qwen/Qwen-7B-Chat at main

 

huggingface.co

위 모델의 코드를 참조하였습니다.

 

if (
            self.use_dynamic_ntk
            and kv_seq_len == hidden_states.size()[1]
            and not self.training
        ):
    context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1 # 올림
    ntk_alpha = max(ntk_alpha, 1)
else:
    ntk_alpha = self.rotary_emb._ntk_alpha_cached
   
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)

for idx in range(len(rotary_pos_emb)):
    rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)

kv_seq_len 은 current_sequence_length 이고, self.seq_length 는 original_model_context_length 를 나타낸다.

 

 

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        if importlib.util.find_spec("einops") is None:
            raise RuntimeError("einops is required for Rotary Embedding")

        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0

    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
        seqlen = max_seq_len + offset
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
                    / self.dim
                )
            )
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
            seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
            
            emb = torch.cat((freqs, freqs), dim=-1)
            from einops import rearrange

            emb = rearrange(emb, "n d -> 1 n 1 d")

            cos, sin = emb.cos(), emb.sin()
            self._rotary_pos_emb_cache = [cos, sin]

    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
        cos, sin = self._rotary_pos_emb_cache
        return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]

 

결과

NTK 레딧

llama-2 모델을 사용하여 실험을 진행하였다.

NTK 레딧

파인튜닝을 하지 않아도 좋은 성능을 얻을 수 있다는 것은 매우 중요한 특징을 가진다. 

LLM 은 fine-tuning 을 하는 데에도 매우 많은 computational power 가 필요하기 때문이다.

그래프에 나온 결과는 그림과 같이 파인튜닝을 하지 않았다.

안해도 성능이 explode 되지 않는다는 것을 보여준다. 하지만 성능이 높아지려면 파인튜닝해야 한다고 한다.

'LLM 관련 논문 정리' 카테고리의 다른 글

Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning (NIPS, 2208)  (0) 2023.09.30
PEFT (parameter-efficient fine tuning) 정리  (0) 2023.09.18
InstructGPT 상세 리뷰  (0) 2023.08.30
Rotary Position Embedding (RoPE)  (0) 2023.08.21
LoRA (Low-Rank Adaptation of Large Language Models)  (0) 2023.08.10
  1. 개념 간단 정리
  2. 구현 설명
  3. 결과
'LLM 관련 논문 정리' 카테고리의 다른 글
  • PEFT (parameter-efficient fine tuning) 정리
  • InstructGPT 상세 리뷰
  • Rotary Position Embedding (RoPE)
  • LoRA (Low-Rank Adaptation of Large Language Models)
섬섬옥수수
섬섬옥수수
컴공 AI 개발자가 되기 위한 노역입니다
섬섬옥수수
아날로그 인간의 컴공 되기
섬섬옥수수
전체
오늘
어제
  • 분류 전체보기
    • 백준 단계별 코딩 테스트
    • KB 논문 정리
    • Memory network 논문 정리
    • LLM 관련 논문 정리
    • Python 및 Torch 코딩 이모저모
    • Clustering 관련 논문 정리
    • 머신러닝 이모저모
    • 암호학

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

인기 글

태그

  • efficient and effective vocabulary expansion towards multilingual large language models
  • 코딩테스트
  • 티스토리챌린지
  • dependency tree
  • 하드웨어
  • ragas
  • 소프트웨어
  • PEFT
  • GIT
  • e5-v
  • 인공지능융합기반시스템개론
  • 오블완
  • 심재형
  • CUDA
  • eeve
  • 백준
  • 이화여대
  • vocabulary expansion
  • 문제풀이
  • constituency tree

최근 댓글

최근 글

hELLO · Designed By 정상우.v4.2.0
섬섬옥수수
NTK-aware dynamic interpolation
상단으로

티스토리툴바

개인정보

  • 티스토리 홈
  • 포럼
  • 로그인

단축키

내 블로그

내 블로그 - 관리자 홈 전환
Q
Q
새 글 쓰기
W
W

블로그 게시글

글 수정 (권한 있는 경우)
E
E
댓글 영역으로 이동
C
C

모든 영역

이 페이지의 URL 복사
S
S
맨 위로 이동
T
T
티스토리 홈 이동
H
H
단축키 안내
Shift + /
⇧ + /

* 단축키는 한글/영문 대소문자로 이용 가능하며, 티스토리 기본 도메인에서만 동작합니다.