---
title: "누스리서치, 51만 토큰 문맥에서 어텐션 학습을 17배 빠르게 만든 라이트하우스 어텐션 공개"
published: 2026-05-19T22:51:33.000Z
canonical: https://jeff.news/article/3022
---
# 누스리서치, 51만 토큰 문맥에서 어텐션 학습을 17배 빠르게 만든 라이트하우스 어텐션 공개

누스리서치가 장문맥 사전학습의 병목인 어텐션 계산을 줄이는 라이트하우스 어텐션을 공개했어. 51만 토큰 문맥에서 단일 B200 기준 forward+backward가 표준 어텐션보다 약 17배 빠르고, 9.8만 토큰 학습에서는 전체 사전학습 속도가 1.4~1.7배 빨라졌다는 결과를 냈어. 핵심은 희소 어텐션 커널을 새로 만들지 않고, 선택된 토큰 묶음을 dense subsequence로 모아 기존 FlashAttention을 그대로 쓰는 점임.

## 장문맥 학습의 병목을 정면으로 찌른 방식

- 누스리서치가 공개한 라이트하우스 어텐션은 장문맥 사전학습에서 가장 비싼 어텐션 계산을 줄이는 선택 기반 계층 어텐션임
  - 단일 B200에서 51만 토큰 문맥 기준 forward+backward latency가 표준 어텐션보다 약 17배 빠름
  - 9.8만 토큰 문맥의 end-to-end pretraining에서는 1.4~1.7배 속도 향상을 보고함
  - 실험은 530M Llama-3, 1만6000 optimizer step, 약 500억 토큰 규모로 진행됨

> [!IMPORTANT]
> 이 결과의 핵심은 “희소 어텐션이 빠르다”가 아니라, 스파스하게 학습한 모델을 다시 dense attention 모델로 복구했을 때 dense-from-scratch와 맞먹거나 더 좋았다는 점임.

- 기존 장문맥 학습의 벽은 어텐션의 제곱 비용임
  - FlashAttention이 상수를 줄여주긴 하지만, 문맥 길이 N이 커질수록 표준 어텐션 비용은 여전히 N²로 올라감
  - 그래서 현실적으로는 “원하는 문맥”이 아니라 “감당 가능한 문맥”에서 학습하게 됨

## 라이트하우스가 기존 sparse attention과 다른 점

- Q, K, V를 모두 같은 방식으로 풀링하는 대칭 구조를 씀
  - 기존 NSA, HISA, InfLLM-v2, DSA, MoBA 같은 접근은 대체로 query는 full resolution에 두고 key/value만 압축하는 비대칭 설계를 택함
  - 라이트하우스는 Q, K, V를 모두 L-level pyramid로 평균 풀링해서 같은 해상도 공간에서 비교하게 만듦
  - 이 선택 덕분에 학습 시 dense attention 호출 비용이 N 곱하기 S가 아니라 S² 쪽으로 줄어듦

- 선택 점수는 학습 가능한 scorer가 아니라 per-head L2 norm으로 계산함
  - 각 피라미드 항목에서 query projection norm과 key projection norm을 뽑아 Top-K를 고름
  - learned scorer head, auxiliary loss, Gumbel-softmax, straight-through estimator가 없음
  - QK 상호작용을 직접 보는 scorer보다 약한 신호인데도 결과가 나왔으니, 저자들은 이걸 selection-based training의 하한선에 가깝게 봄

- 제일 실용적인 포인트는 선택 로직이 attention kernel 바깥에 있다는 것임
  - Top-K로 고른 항목을 contiguous하고 causal order가 맞는 dense subsequence로 gather함
  - 그 뒤에는 일반 FlashAttention을 그대로 실행함
  - 커스텀 sparse attention kernel이 없으니, upstream FlashAttention 개선을 그대로 물려받을 수 있음

```mermaid
sequenceDiagram
    participant 피라미드풀링
    participant 선택로직
    participant 밀집부분수열
    participant 플래시어텐션
    participant 원래시퀀스
    피라미드풀링->>선택로직: Q, K, V를 여러 해상도로 풀링
    선택로직->>선택로직: L2 norm 기반 Top-K 선택
    선택로직->>밀집부분수열: 선택 항목을 causal 순서로 gather/sort
    밀집부분수열->>플래시어텐션: 작은 dense attention 실행
    플래시어텐션->>원래시퀀스: 결과를 base position으로 scatter-back
```

## 네 단계 파이프라인

- 첫 단계는 Q, K, V를 모두 L-level pyramid로 평균 풀링하는 것임
  - level 0은 원래 시퀀스고, level l은 N 나누기 p^l 길이를 가짐
  - 예시로 N=16, L=3, p=2면 16개 base token 위에 8개, 4개 요약 셀이 쌓이는 구조임

- 두 번째는 coarse-to-fine Top-K cascade임
  - 가장 거친 level에서 Top-K를 고르고, 살아남은 항목의 child로 내려가 다시 Top-K를 고름
  - rejected coarse entry를 완전히 버리면 causal mask에 구멍이 생기기 때문에, 일부 rejected coarse entry도 buffer에 유지함
  - 각 level은 대략 p 곱하기 K개 이하 항목만 기여하게 설계됨

- 세 번째는 선택된 Q, K, V triple을 모아 FlashAttention에 넣는 것임
  - gather 후 sort해서 base sequence 기준 causal topology를 유지함
  - 그래서 표준 S 곱하기 S lower-triangular causal mask가 그대로 동작함
  - 구현상 대부분은 torch.gather, torch.sort, FlashAttention 조합이고, torchtitan 위에 신규 파일 2개와 약 600줄 수정으로 구성됨

- 네 번째는 attention output을 원래 위치로 scatter-back하는 것임
  - coarse summary가 대표하는 base position 범위로 출력을 다시 뿌림
  - accumulation kernel은 기본 fp-atomic과 재현용 deterministic integer-atomic 두 가지가 있음
  - deterministic 쪽은 1.2~2배 느려서, 실제 기본값은 fp-atomic임

## sparse로 학습해도 dense 모델로 돌아오나

- 저자들이 가장 중요하게 본 검증은 “학습 끝난 뒤에도 dense attention 모델로 쓸 수 있나”임
  - inference-time sparse method는 dense backbone에 얹어 평가하면 되지만, training-time sparse method는 모델 자체가 sparse approximation에 과적합될 수 있음
  - 그래서 라이트하우스는 2단계 학습 레시피를 씀

- 학습 레시피는 Stage 1 라이트하우스 학습, Stage 2 SDPA 재개 학습임
  - 전체 예산 대부분은 Lighthouse selection을 켜고 학습함
  - 마지막 짧은 tail에서 selection을 끄고 표준 attention으로 이어서 학습함
  - optimizer state와 dataloader continuation은 그대로 유지함

- 세 가지 split 모두 최종적으로 dense-from-scratch 기준을 맞추거나 이김
  - 1만+6000, 1만1000+5000, 1만2000+4000 step 조합을 테스트함
  - resume 직후 loss가 1.12~1.57 nats 튀지만, 약 1000~1500 SDPA step 안에 회복함
  - 최종 loss는 0.6980~0.7102 범위로, dense baseline 0.7237보다 좋았음

> [!NOTE]
> 이 loss spike는 꽤 중요한 신호임. 모델이 처음엔 dense attention 사용에 어색해하지만, 짧은 재개 학습으로 회복된다는 건 sparse 학습이 dense 능력을 완전히 망가뜨리진 않았다는 뜻임.

## 성능과 스케일링

- ablation grid에서 라이트하우스는 75~106 B200-hour를 절약함
  - 전체 wall-clock 기준 1.40~1.69배 속도 향상임
  - Stage 1 throughput은 GPU당 8.4만~12.6만 tokens/s를 유지했고, dense SDPA는 약 4.6만 tokens/s였음
  - SDPA-resume tail은 baseline과 같은 커널을 쓰므로 throughput도 baseline과 같음

- pyramid hyperparameter는 생각보다 예민하지 않았음
  - L은 3, 4, 5를 테스트했고 p는 2, 4, 8을 테스트함
  - 결과는 서로 약 0.02 nats 이내에 모였음
  - 품질을 갈라먹는 knife-edge라기보다는 throughput과 memory reach의 트레이드오프에 가까움

- 10만 토큰을 넘어서면 context parallelism이 필요해짐
  - 530M architecture도 100K context 이후에는 attention 방법과 무관하게 단일 B200에서 OOM이 남
  - 라이트하우스는 pyramid pooling, scoring, Top-K를 shard-local로 처리하고, gather된 결과는 dense subsequence라 ring attention에 그대로 태울 수 있음
  - 32개 Blackwell GPU, 4개 노드, CP degree 8에서 100만 토큰 학습을 inner attention kernel 변경 없이 돌렸음

## 한계와 다음 질문

- autoregressive decoding에는 아직 바로 맞지 않음
  - symmetric Q/K/V pooling은 모든 query가 한 forward pass에 같이 존재한다고 가정함
  - 토큰을 하나씩 생성하는 decoding은 이 가정과 맞지 않아서, 현재는 dense-SDPA resume으로 inference-ready 모델을 만드는 쪽에 의존함

- 비용이 완전한 linear은 아님
  - gather된 subsequence 길이 S에 대해 attention 비용은 S² 곱하기 d임
  - 고정 K에서는 N에 대해 sub-quadratic처럼 보이지만, recall을 유지하려고 K가 N과 함께 커져야 하는 영역은 아직 충분히 규명되지 않음

- 다음 연구 방향도 꽤 실전적임
  - dense-SDPA resume 대신 DSA, NSA, HISA, MoBA 같은 inference-oriented sparse target으로 복구하는 방식
  - layer별 또는 head별 adaptive K budget
  - vision, audio, video처럼 자연스럽게 multi-scale 구조가 있는 데이터로 확장
  - continuous batching, speculative decoding, KV-cache 관리까지 포함한 serving integration

---

## 기술 맥락

- 이 작업에서 중요한 선택은 sparse attention을 커널 안으로 밀어 넣지 않은 거예요. 희소 패턴을 직접 처리하는 커널을 만들면 논문 아이디어는 멋져 보일 수 있지만, 실제 학습에서는 backward, 분산 통신, tensor core 최적화, 재현성까지 전부 새로 챙겨야 하거든요.

- 그래서 라이트하우스는 “무엇을 볼지”만 바깥에서 고르고, “어텐션을 어떻게 계산할지”는 FlashAttention에 맡겨요. 이러면 선택 로직은 연구자가 바꿔볼 수 있고, 계산 커널은 이미 검증된 생태계의 성능 개선을 그대로 받는 구조가 돼요.

- SDPA-resume도 단순한 후처리가 아니라 모델 품질 검증 장치에 가까워요. sparse 방식으로만 잘 도는 특수 모델을 만든 게 아니라, 짧은 표준 어텐션 학습으로 dense 모델에 다시 적응할 수 있는지를 본 거라서 학습 방법으로서의 신뢰도가 훨씬 올라가요.

- context parallelism에서 dense subsequence를 유지한 것도 실전성이 커요. 분산 환경에서는 sparse index를 들고 ring rotation을 맞추는 순간 엔지니어링 난도가 확 뛰는데, 라이트하우스는 gather 결과가 연속 tensor라 기존 ring attention 흐름에 태우기 쉬운 편이에요.

## 핵심 포인트

- Q, K, V를 대칭적으로 피라미드 풀링해 선택 기반 계층 어텐션을 구성함
- 선택 로직은 커널 밖에서 처리하고 실제 어텐션 계산은 기존 FlashAttention으로 수행함
- 530M Llama-3, 1만6000 스텝, 약 500억 토큰 실험에서 dense-from-scratch 기준을 맞추거나 이김
- 스파스 학습 뒤 짧은 표준 어텐션 재개 학습을 붙이면 dense attention 모델로 복구 가능함
- 32개 B200, 컨텍스트 병렬화에서 100만 토큰 학습까지 검증함

## 인사이트

장문맥 모델 학습에서 진짜 골치 아픈 건 멋진 희소 패턴 자체보다, 그걸 학습 커널·분산 학습·최종 dense 모델 품질까지 끌고 가는 일임. 라이트하우스 어텐션은 선택은 밖에서 하고 계산은 검증된 FlashAttention에 맡기는 식으로, 연구 아이디어를 실제 학습 파이프라인에 붙이기 쉽게 만든 게 포인트야.
