본문으로 건너뛰기
피드

누스리서치, 51만 토큰 문맥에서 어텐션 학습을 17배 빠르게 만든 라이트하우스 어텐션 공개

ai-ml 약 13분
vote
0
댓글
북마크

누스리서치가 장문맥 사전학습의 병목인 어텐션 계산을 줄이는 라이트하우스 어텐션을 공개했어. 51만 토큰 문맥에서 단일 B200 기준 forward+backward가 표준 어텐션보다 약 17배 빠르고, 9.8만 토큰 학습에서는 전체 사전학습 속도가 1.4~1.7배 빨라졌다는 결과를 냈어. 핵심은 희소 어텐션 커널을 새로 만들지 않고, 선택된 토큰 묶음을 dense subsequence로 모아 기존 FlashAttention을 그대로 쓰는 점임.

  • 1

    Q, K, V를 대칭적으로 피라미드 풀링해 선택 기반 계층 어텐션을 구성함

  • 2

    선택 로직은 커널 밖에서 처리하고 실제 어텐션 계산은 기존 FlashAttention으로 수행함

  • 3

    530M Llama-3, 1만6000 스텝, 약 500억 토큰 실험에서 dense-from-scratch 기준을 맞추거나 이김

  • 4

    스파스 학습 뒤 짧은 표준 어텐션 재개 학습을 붙이면 dense attention 모델로 복구 가능함

  • 5

    32개 B200, 컨텍스트 병렬화에서 100만 토큰 학습까지 검증함

장문맥 학습의 병목을 정면으로 찌른 방식

  • 누스리서치가 공개한 라이트하우스 어텐션은 장문맥 사전학습에서 가장 비싼 어텐션 계산을 줄이는 선택 기반 계층 어텐션임
    • 단일 B200에서 51만 토큰 문맥 기준 forward+backward latency가 표준 어텐션보다 약 17배 빠름
    • 9.8만 토큰 문맥의 end-to-end pretraining에서는 1.4~1.7배 속도 향상을 보고함
    • 실험은 530M Llama-3, 1만6000 optimizer step, 약 500억 토큰 규모로 진행됨

중요

> 이 결과의 핵심은 “희소 어텐션이 빠르다”가 아니라, 스파스하게 학습한 모델을 다시 dense attention 모델로 복구했을 때 dense-from-scratch와 맞먹거나 더 좋았다는 점임.

  • 기존 장문맥 학습의 벽은 어텐션의 제곱 비용임
    • FlashAttention이 상수를 줄여주긴 하지만, 문맥 길이 N이 커질수록 표준 어텐션 비용은 여전히 N²로 올라감
    • 그래서 현실적으로는 “원하는 문맥”이 아니라 “감당 가능한 문맥”에서 학습하게 됨

라이트하우스가 기존 sparse attention과 다른 점

  • Q, K, V를 모두 같은 방식으로 풀링하는 대칭 구조를 씀

    • 기존 NSA, HISA, InfLLM-v2, DSA, MoBA 같은 접근은 대체로 query는 full resolution에 두고 key/value만 압축하는 비대칭 설계를 택함
    • 라이트하우스는 Q, K, V를 모두 L-level pyramid로 평균 풀링해서 같은 해상도 공간에서 비교하게 만듦
    • 이 선택 덕분에 학습 시 dense attention 호출 비용이 N 곱하기 S가 아니라 S² 쪽으로 줄어듦
  • 선택 점수는 학습 가능한 scorer가 아니라 per-head L2 norm으로 계산함

    • 각 피라미드 항목에서 query projection norm과 key projection norm을 뽑아 Top-K를 고름
    • learned scorer head, auxiliary loss, Gumbel-softmax, straight-through estimator가 없음
    • QK 상호작용을 직접 보는 scorer보다 약한 신호인데도 결과가 나왔으니, 저자들은 이걸 selection-based training의 하한선에 가깝게 봄
  • 제일 실용적인 포인트는 선택 로직이 attention kernel 바깥에 있다는 것임

    • Top-K로 고른 항목을 contiguous하고 causal order가 맞는 dense subsequence로 gather함
    • 그 뒤에는 일반 FlashAttention을 그대로 실행함
    • 커스텀 sparse attention kernel이 없으니, upstream FlashAttention 개선을 그대로 물려받을 수 있음
sequenceDiagram
    participant 피라미드풀링
    participant 선택로직
    participant 밀집부분수열
    participant 플래시어텐션
    participant 원래시퀀스
    피라미드풀링->>선택로직: Q, K, V를 여러 해상도로 풀링
    선택로직->>선택로직: L2 norm 기반 Top-K 선택
    선택로직->>밀집부분수열: 선택 항목을 causal 순서로 gather/sort
    밀집부분수열->>플래시어텐션: 작은 dense attention 실행
    플래시어텐션->>원래시퀀스: 결과를 base position으로 scatter-back

네 단계 파이프라인

  • 첫 단계는 Q, K, V를 모두 L-level pyramid로 평균 풀링하는 것임

    • level 0은 원래 시퀀스고, level l은 N 나누기 p^l 길이를 가짐
    • 예시로 N=16, L=3, p=2면 16개 base token 위에 8개, 4개 요약 셀이 쌓이는 구조임
  • 두 번째는 coarse-to-fine Top-K cascade임

    • 가장 거친 level에서 Top-K를 고르고, 살아남은 항목의 child로 내려가 다시 Top-K를 고름
    • rejected coarse entry를 완전히 버리면 causal mask에 구멍이 생기기 때문에, 일부 rejected coarse entry도 buffer에 유지함
    • 각 level은 대략 p 곱하기 K개 이하 항목만 기여하게 설계됨
  • 세 번째는 선택된 Q, K, V triple을 모아 FlashAttention에 넣는 것임

    • gather 후 sort해서 base sequence 기준 causal topology를 유지함
    • 그래서 표준 S 곱하기 S lower-triangular causal mask가 그대로 동작함
    • 구현상 대부분은 torch.gather, torch.sort, FlashAttention 조합이고, torchtitan 위에 신규 파일 2개와 약 600줄 수정으로 구성됨
  • 네 번째는 attention output을 원래 위치로 scatter-back하는 것임

    • coarse summary가 대표하는 base position 범위로 출력을 다시 뿌림
    • accumulation kernel은 기본 fp-atomic과 재현용 deterministic integer-atomic 두 가지가 있음
    • deterministic 쪽은 1.2~2배 느려서, 실제 기본값은 fp-atomic임

sparse로 학습해도 dense 모델로 돌아오나

  • 저자들이 가장 중요하게 본 검증은 “학습 끝난 뒤에도 dense attention 모델로 쓸 수 있나”임

    • inference-time sparse method는 dense backbone에 얹어 평가하면 되지만, training-time sparse method는 모델 자체가 sparse approximation에 과적합될 수 있음
    • 그래서 라이트하우스는 2단계 학습 레시피를 씀
  • 학습 레시피는 Stage 1 라이트하우스 학습, Stage 2 SDPA 재개 학습임

    • 전체 예산 대부분은 Lighthouse selection을 켜고 학습함
    • 마지막 짧은 tail에서 selection을 끄고 표준 attention으로 이어서 학습함
    • optimizer state와 dataloader continuation은 그대로 유지함
  • 세 가지 split 모두 최종적으로 dense-from-scratch 기준을 맞추거나 이김

    • 1만+6000, 1만1000+5000, 1만2000+4000 step 조합을 테스트함
    • resume 직후 loss가 1.121.57 nats 튀지만, 약 10001500 SDPA step 안에 회복함
    • 최종 loss는 0.6980~0.7102 범위로, dense baseline 0.7237보다 좋았음

ℹ️참고

> 이 loss spike는 꽤 중요한 신호임. 모델이 처음엔 dense attention 사용에 어색해하지만, 짧은 재개 학습으로 회복된다는 건 sparse 학습이 dense 능력을 완전히 망가뜨리진 않았다는 뜻임.

성능과 스케일링

  • ablation grid에서 라이트하우스는 75~106 B200-hour를 절약함

    • 전체 wall-clock 기준 1.40~1.69배 속도 향상임
    • Stage 1 throughput은 GPU당 8.4만~12.6만 tokens/s를 유지했고, dense SDPA는 약 4.6만 tokens/s였음
    • SDPA-resume tail은 baseline과 같은 커널을 쓰므로 throughput도 baseline과 같음
  • pyramid hyperparameter는 생각보다 예민하지 않았음

    • L은 3, 4, 5를 테스트했고 p는 2, 4, 8을 테스트함
    • 결과는 서로 약 0.02 nats 이내에 모였음
    • 품질을 갈라먹는 knife-edge라기보다는 throughput과 memory reach의 트레이드오프에 가까움
  • 10만 토큰을 넘어서면 context parallelism이 필요해짐

    • 530M architecture도 100K context 이후에는 attention 방법과 무관하게 단일 B200에서 OOM이 남
    • 라이트하우스는 pyramid pooling, scoring, Top-K를 shard-local로 처리하고, gather된 결과는 dense subsequence라 ring attention에 그대로 태울 수 있음
    • 32개 Blackwell GPU, 4개 노드, CP degree 8에서 100만 토큰 학습을 inner attention kernel 변경 없이 돌렸음

한계와 다음 질문

  • autoregressive decoding에는 아직 바로 맞지 않음

    • symmetric Q/K/V pooling은 모든 query가 한 forward pass에 같이 존재한다고 가정함
    • 토큰을 하나씩 생성하는 decoding은 이 가정과 맞지 않아서, 현재는 dense-SDPA resume으로 inference-ready 모델을 만드는 쪽에 의존함
  • 비용이 완전한 linear은 아님

    • gather된 subsequence 길이 S에 대해 attention 비용은 S² 곱하기 d임
    • 고정 K에서는 N에 대해 sub-quadratic처럼 보이지만, recall을 유지하려고 K가 N과 함께 커져야 하는 영역은 아직 충분히 규명되지 않음
  • 다음 연구 방향도 꽤 실전적임

    • dense-SDPA resume 대신 DSA, NSA, HISA, MoBA 같은 inference-oriented sparse target으로 복구하는 방식
    • layer별 또는 head별 adaptive K budget
    • vision, audio, video처럼 자연스럽게 multi-scale 구조가 있는 데이터로 확장
    • continuous batching, speculative decoding, KV-cache 관리까지 포함한 serving integration

기술 맥락

  • 이 작업에서 중요한 선택은 sparse attention을 커널 안으로 밀어 넣지 않은 거예요. 희소 패턴을 직접 처리하는 커널을 만들면 논문 아이디어는 멋져 보일 수 있지만, 실제 학습에서는 backward, 분산 통신, tensor core 최적화, 재현성까지 전부 새로 챙겨야 하거든요.

  • 그래서 라이트하우스는 “무엇을 볼지”만 바깥에서 고르고, “어텐션을 어떻게 계산할지”는 FlashAttention에 맡겨요. 이러면 선택 로직은 연구자가 바꿔볼 수 있고, 계산 커널은 이미 검증된 생태계의 성능 개선을 그대로 받는 구조가 돼요.

  • SDPA-resume도 단순한 후처리가 아니라 모델 품질 검증 장치에 가까워요. sparse 방식으로만 잘 도는 특수 모델을 만든 게 아니라, 짧은 표준 어텐션 학습으로 dense 모델에 다시 적응할 수 있는지를 본 거라서 학습 방법으로서의 신뢰도가 훨씬 올라가요.

  • context parallelism에서 dense subsequence를 유지한 것도 실전성이 커요. 분산 환경에서는 sparse index를 들고 ring rotation을 맞추는 순간 엔지니어링 난도가 확 뛰는데, 라이트하우스는 gather 결과가 연속 tensor라 기존 ring attention 흐름에 태우기 쉬운 편이에요.

장문맥 모델 학습에서 진짜 골치 아픈 건 멋진 희소 패턴 자체보다, 그걸 학습 커널·분산 학습·최종 dense 모델 품질까지 끌고 가는 일임. 라이트하우스 어텐션은 선택은 밖에서 하고 계산은 검증된 FlashAttention에 맡기는 식으로, 연구 아이디어를 실제 학습 파이프라인에 붙이기 쉽게 만든 게 포인트야.

댓글

댓글

댓글을 불러오는 중...

ai-ml

유튜브, AI 생성 영상에 자동 라벨 붙인다

유튜브가 사실적으로 보이거나 의미 있게 AI로 변경·생성된 콘텐츠에 더 눈에 띄는 라벨을 적용하고, 제작자가 AI 사용 여부를 밝히지 않아도 내부 신호로 감지되면 자동 라벨을 붙이겠다고 밝혔다. 다만 라벨만으로 추천 노출이나 수익화 자격이 바뀌지는 않으며, 제작자는 YouTube Studio에서 잘못된 판정을 수정할 수 있다.

ai-ml

테크 CEO들의 'AI 만능론', 숫자는 아직 그렇게 말하지 않는다

테크 업계에서 AI를 이유로 한 대규모 감원과 조직 재편이 이어지는 가운데, Box 창업자 애런 레비는 CEO들이 실제 업무의 마지막 1마일을 모른 채 AI 에이전트의 능력을 과대평가하고 있다고 지적했다. 2026년 첫 5개월 동안 이미 11만5430명이 해고됐고, 여러 연구는 AI 도입이 체감 생산성만큼 실제 생산성을 끌어올렸다는 근거가 아직 약하다고 말한다.

ai-ml

오픈AI와 앤트로픽, 코딩 에이전트로 드디어 돈 되는 시장을 찾은 듯

사이먼 윌리슨은 오픈AI와 앤트로픽이 코딩 에이전트와 기업용 과금으로 진짜 제품-시장 적합성을 찾았다고 봐. 개인 구독자에게는 월 100달러 플랜이 싸게 느껴지지만, 기업 고객은 이제 사용량 기준 토큰 가격을 그대로 내기 시작했고 이게 대형 고객 예산을 빠르게 흔들고 있다는 얘기야.

ai-ml

컴팔과 GMI 클라우드, 대규모 추론용 AI 인프라 구축 협력

컴팔이 실리콘밸리 기반 AI 인프라 기업 GMI 클라우드와 협력해 대규모 추론과 에이전틱 AI 워크로드에 맞춘 GPU 서버 인프라를 구축한다고 발표했어. COMPUTEX 2026에서는 NVIDIA HGX B300을 지원하는 Compal SGX30-2 같은 고성능 AI 서버 플랫폼도 선보일 예정이야.

ai-ml

AI 쓰면 편해진다더니, 직장인들은 ‘AI 과부하’에 지쳐가는 중

국내 직장인들이 AI 전환 압박, AI 답변 검증 부담, 대체 불안 때문에 피로감을 호소하고 있어. 중앙일보 설문에서는 5284명 중 31.6%가 ‘AI 답변 검증에 시간이 더 걸릴 때’를 가장 지치는 순간으로 꼽았고, 기업들은 무작정 AI 사용량을 밀어붙이는 방식에서 업무 방식 재설계로 넘어가야 한다는 지적이 나와.