---
title: "FlashAttention-2를 CuTe로 밑바닥부터 구현하며 겪은 GPU 커널 지옥기"
published: 2026-05-18T22:53:30.000Z
canonical: https://jeff.news/article/2936
---
# FlashAttention-2를 CuTe로 밑바닥부터 구현하며 겪은 GPU 커널 지옥기

글쓴이는 FlashAttention-2를 Triton으로는 이틀 만에 구현했지만, 같은 알고리즘을 NVIDIA CuTe 기반 C++ 커널로 옮기는 데는 몇 주와 약 100시간의 해설 작업이 걸렸다고 말한다. 이 글은 A100에서 GMEM·SMEM 비동기 복사, tiled MMA, swizzling, online softmax, epilogue store까지 생산 수준 커널의 핵심을 한 줄씩 파고드는 튜토리얼에 가깝다.

## Triton 이틀, CuTe 몇 주

- 글쓴이는 FlashAttention-2를 Triton으로는 이틀 만에 구현했다고 함
  - GPU kernel을 처음 만져본 상태였는데도, Triton의 `tl.dot` 같은 고수준 추상화 덕분에 꽤 빠르게 결과가 나옴
  - 그런데 같은 걸 C++ CuTe로 옮기자 상황이 완전히 달라짐
  - 몇 주 동안 CUDA layout, shared memory, swizzling, MMA fragment, async copy를 하나씩 파고들어야 했고, 블로그 작성까지 포함해 약 100시간을 썼다고 밝힘

- 이 글은 FlashAttention-2 논문 요약이 아니라 “생산급 GPU 커널을 실제로 어떻게 짜는가”에 가까움
  - 대상 GPU는 A100, 즉 Ampere 세대임
  - fp16 입력, fp32 accumulation, causal masking·RoPE·dropout·KV-cache 같은 분기는 제거한 단순 attention을 다룸
  - 그래도 핵심 경로는 Tri Dao의 production FA-2와 같은 idiom, 같은 building block을 따라감

- 성능도 장난감 수준으로 끝나지 않음
  - A100에서 head dimension 64·128, sequence length 최대 64K 조건을 다룸
  - 단순화한 커널이 production FA-2 대비 88-105% throughput을 낸다고 밝힘
  - peak 기준 fp16 tensor core utilization은 63%까지 나옴
  - 목적은 새 알고리즘이 아니라, 단순화해도 핵심 성능 구조가 유지된다는 걸 보여주는 쪽임

> [!IMPORTANT]
> 이 글의 제일 센 숫자는 “production FA-2 대비 88-105%”임. 학습용으로 분기를 걷어낸 커널인데도 성능이 비슷하게 나온다는 건, 핵심 병목과 레이아웃 결정을 제대로 잡았다는 뜻임.

## CuTe가 왜 빡센가

- CuTe는 CUTLASS 3.x 안의 layout algebra 코어임
  - CUTLASS는 NVIDIA의 GEMM과 GEMM 인접 커널용 오픈소스 빌딩 블록 라이브러리임
  - CuTe는 shape, stride, layout, tile, thread mapping을 템플릿으로 조합하게 해줌
  - PyTorch tensor처럼 보이는 객체를 쓰지만, 실제로는 훨씬 낮은 레벨에서 메모리와 레지스터 배치를 직접 다룸

- Triton과 CuTe의 철학 차이가 이 글의 큰 축임
  - Triton은 `tl.dot(q, k)`를 쓰면 MMA atom, shared memory swizzle, tiling을 상당 부분 알아서 잡아줌
  - CuTe는 어떤 MMA atom을 쓸지, 어떤 copy atom을 쓸지, SMEM 레이아웃을 어떻게 꼬아 bank conflict를 피할지 직접 결정하게 함
  - 글쓴이 표현대로라면 Triton은 하드웨어를 숨기고, CuTe는 하드웨어를 정면으로 보게 만듦

- 그래서 CuTe를 잘 쓰려면 GPU 구조를 알아야 함
  - thread, warp, thread block, SM, tensor core, shared memory bank 같은 개념이 전부 코드에 드러남
  - register fragment는 tensor처럼 보이지만 실제 주소 지정 가능한 메모리가 아니라 컴파일러가 물리 레지스터에 매핑한 추상 구조임
  - 이 차이를 모르면 layout을 보고도 “왜 이게 맞지?”에서 계속 막힘

## 커널 구조는 단순한데, 디테일이 지옥임

- Attention 자체는 익숙한 그 식임
  - `P = QKᵀ / sqrt(d_h)`
  - `S = softmax(P)`
  - `O = S V`
  - 문제는 이걸 A100 tensor core와 shared memory pipeline에 맞춰 어떻게 쪼개느냐임

- 이 구현의 grid는 batch/head와 Q tile을 기준으로 잡음
  - 각 thread block은 하나의 Q tile을 맡고, Q tile은 block이 끝날 때까지 변하지 않음
  - K와 V tile을 반복해서 가져오며 해당 Q tile의 output tile을 완성함
  - Q, K, V는 global memory에서 shared memory로 staging한 뒤, 다시 register fragment로 올려 MMA에 넣음

- 핵심 loop는 대략 이런 흐름임
  - Q tile을 한 번 global memory에서 shared memory로 가져옴
  - 0번째 K tile을 미리 가져옴
  - K가 도착하면 `QKᵀ` GEMM을 날림
  - 그 사이 V tile을 prefetch함
  - softmax 통계를 갱신하고 `SV` GEMM으로 output accumulator를 업데이트함
  - 다음 K tile prefetch와 현재 V 작업을 겹치며 sequence 전체를 돈 뒤, 마지막 output을 global memory에 저장함

```mermaid
sequenceDiagram
    participant 전역메모리 as 전역 메모리
    participant 공유메모리 as 공유 메모리
    participant 레지스터 as 레지스터 조각
    participant 텐서코어 as 텐서 코어
    participant 소프트맥스 as 온라인 소프트맥스
    전역메모리->>공유메모리: Q 타일과 첫 K 타일 비동기 복사
    공유메모리->>레지스터: LDSM으로 Q/K 조각 로드
    레지스터->>텐서코어: QKᵀ MMA 실행
    전역메모리->>공유메모리: V 타일 prefetch
    텐서코어->>소프트맥스: attention score 전달
    소프트맥스->>레지스터: max/sum 갱신과 누적값 rescale
    레지스터->>텐서코어: S·V MMA 실행
    레지스터->>전역메모리: 최종 O 타일 저장
```

## 성능을 가르는 결정들

- A100에서는 GMEM에서 SMEM으로 `cp.async` 비동기 복사를 쓸 수 있음
  - HBM에서 shared memory로 데이터를 가져오는 동안 tensor core 계산을 진행해 지연을 숨기는 구조임
  - Q는 block 동안 고정이라 한 번만 가져오고, K와 V는 loop마다 prefetch함
  - FA2는 여러 단계 앞까지 깊게 prefetch하지 않고, 한 block 앞 정도만 가져가는 단순한 전략을 쓴다고 설명함

- vectorized load와 coalesced load를 구분해서 설명하는 부분도 실전적임
  - vectorized load는 한 thread가 한 번에 128-bit, 즉 fp16 8개를 가져오는 쪽임
  - coalesced load는 warp 단위로 128-byte 연속 블록을 한 transaction으로 가져오는 쪽임
  - 둘 다 데이터가 연속적이어야 제대로 이득을 봄
  - fp16 하나씩 16-bit로 읽으면 128-bit load 대비 이론상 8배 손해가 날 수 있다고 짚음

- shared memory는 빠르지만 bank conflict가 바로 발목을 잡음
  - NVIDIA shared memory는 32개 bank로 나뉘고, 같은 warp의 여러 thread가 같은 bank의 서로 다른 주소를 동시에 읽으면 직렬화됨
  - 32x32 float 배열을 column 방향으로 읽으면 모든 thread가 같은 bank를 때리는 32-way conflict가 생길 수 있음
  - padding으로 해결할 수도 있지만, FA2에서는 메모리 footprint와 접근 패턴 때문에 swizzling이 더 맞는 선택임

- swizzling은 shared memory 주소를 XOR 기반으로 재배치해 bank conflict를 피하는 기법으로 설명됨
  - 논리적으로는 같은 `(row, col)`을 읽고 쓰지만, 물리 주소는 bank가 골고루 퍼지도록 꼬아둠
  - FA2에서는 fp16 8개짜리 128-bit chunk는 그대로 연속이어야 하므로, 하위 3비트는 보존하고 그 위 bit들을 XOR로 섞는 식의 Swizzle을 씀
  - 대표적으로 `Swizzle<3, 3, 3>` 패턴을 다룸

> [!NOTE]
> CuTe layout은 “데이터가 이렇게 생겼다”가 아니라 “이 주소와 이 thread mapping으로 데이터를 해석하겠다”에 더 가까움. 그래서 row-major, column-major, colex indexing, swizzle이 한꺼번에 나오면 머리가 터지는 게 정상임.

## Online softmax가 거의 공짜가 되는 이유

- 이 커널의 결정적인 layout 선택은 warp를 M 방향으로 타일링하는 것임
  - 각 warp가 `QKᵀ` output의 한 row 전체를 소유하도록 배치함
  - softmax의 row max와 row sum reduction이 warp 내부에서 끝남
  - `__shfl_xor_sync` 같은 warp shuffle primitive로 shared memory staging이나 `__syncthreads()` 없이 처리할 수 있음

- 만약 warp를 N 방향으로 나눴다면 softmax가 훨씬 비싸졌을 가능성이 큼
  - 한 row의 값이 여러 warp에 쪼개지면 reduction이 warp 경계를 넘어가야 함
  - 그러면 shared memory에 중간값을 넣고 thread block sync를 해야 함
  - FlashAttention-1에서 더 부담스러웠던 부분을 FA2는 이 배치로 줄인 셈임

- online softmax는 tile을 돌면서 max와 exp sum을 갱신함
  - 새 score tile의 row max를 구함
  - 기존 max와 비교해 새 max를 만들고, 이전 accumulator를 보정 계수로 rescale함
  - exp sum도 누적 갱신함
  - 전체 attention matrix를 저장하지 않아도 softmax 결과와 output accumulator를 일관되게 유지할 수 있음

## V copy와 이상한 no-op 이야기

- V는 Q, K보다 복사가 더 까다로움
  - `S @ V`에서는 concat dimension이 K/V sequence 축이라, V를 MMA가 기대하는 모양으로 사실상 transpose해서 읽어야 함
  - GMEM에서 SMEM으로는 row-major 그대로 빠르게 복사하고, SMEM에서 register로 올릴 때 `LDSM_T` 계열 transposed load를 사용함

- 글에서 제일 흥미로운 디버깅 포인트는 `sVtNoSwizzle`임
  - production FA2 쪽 코드에는 V fragment shape을 얻기 위해 swizzle을 제거한 듯한 layout이 등장함
  - 처음 보면 “이거 안 하면 뭔가 깨지나?” 싶지만, 글쓴이가 테스트해보니 실제로는 없어도 결과가 같았다고 함
  - hdim 32·96 같은 케이스에서 fragment shape 출력이 이상하게 보이는 이유는 swizzle layout과 `partition_fragment`의 layout 추론 방식 때문인데, 실제 copy와 MMA가 일관되게 같은 mapping을 쓰면 계산은 맞게 돌아감

- 여기서 글쓴이는 더 단순한 선택지도 제안함
  - hdim에 따라 swizzle pattern을 바꾸는 대신, 관련 hdim에 `Swizzle<3,3,3>`를 일관되게 쓰면 shape inconsistency가 줄어든다고 봄
  - 실제 테스트에서도 문제가 없었다고 설명함
  - 이런 부분이 “코드를 따라 쓰는 것”과 “왜 되는지 이해하는 것”의 차이를 보여줌

## 그래서 누가 읽어야 하나

- 이 글은 FlashAttention-2를 쓰는 사람보다 만드는 사람에게 더 맞음
  - PyTorch나 JAX에서 attention 성능만 챙기는 독자라면 과함
  - CUDA kernel, Triton 다음 단계, CUTLASS/CuTe, tensor core 최적화를 공부하는 사람에게는 꽤 귀한 자료임
  - 특히 문서가 reference에 머물러 있고 튜토리얼이 부족한 CuTe 생태계에서 실제 커널 하나를 끝까지 따라가는 글이라는 점이 큼

- 한국 개발자 입장에서도 가치가 높음
  - LLM serving, training 최적화, custom CUDA extension을 다루는 팀이라면 attention kernel의 병목을 이해하는 데 직접 도움이 됨
  - Triton으로 빠르게 실험할지, CuTe/CUDA로 하드웨어를 더 깊게 제어할지 판단하는 기준도 줌
  - “고성능 커널은 왜 이렇게 읽기 어려운가”에 대한 현실적인 답이 들어 있음

---

## 기술 맥락

- 이 글에서 선택한 기술적 방향은 A100에 맞춘 FlashAttention-2 forward kernel을 C++ CuTe로 직접 구현하는 거예요. 왜 굳이 어려운 CuTe를 택했냐면, Triton은 빠르게 성능을 내기 좋지만 shared memory layout, MMA atom, register fragment 배치를 세밀하게 이해하기에는 너무 많은 걸 숨겨주거든요.

- 핵심 트레이드오프는 개발 시간과 하드웨어 제어권이에요. Triton은 글쓴이가 이틀 만에 구현할 정도로 생산성이 높지만, CuTe는 몇 주가 걸렸고 layout algebra 때문에 계속 막혔어요. 대신 CuTe를 통과하면 Ampere의 `cp.async`, LDSM, swizzling, warp reduction이 실제 성능에 어떻게 연결되는지 몸으로 알게 돼요.

- warp를 M 방향으로 배치한 결정은 online softmax 때문이에요. softmax는 row 단위 max와 sum이 필요한데, row가 warp 안에 온전히 들어오면 `__shfl_xor_sync`로 reduction을 끝낼 수 있어요. 반대로 row가 여러 warp에 나뉘면 shared memory와 block sync가 필요해지고, attention kernel에서 바로 병목이 돼요.

- swizzling을 넣은 이유는 shared memory가 무조건 빠른 공간이 아니기 때문이에요. 32개 bank에 접근이 몰리면 load가 직렬화되니까, 논리 좌표는 유지하되 물리 주소를 XOR로 꼬아서 bank를 분산해야 해요. FA2에서는 128-bit fp16 chunk의 연속성도 지켜야 해서 swizzle bit 선택이 성능과 정확성 모두에 영향을 줘요.

- 결과적으로 이 글은 “FlashAttention-2가 빠르다”보다 “왜 빠른 구현은 이런 모양이 되는가”를 설명해요. A100, fp16, fp32 accumulation, head_dim 64·128, sequence length 최대 64K라는 조건 안에서 메모리 이동과 tensor core 계산을 어떻게 겹칠지 하나씩 결정한 기록에 가깝거든요.

## 핵심 포인트

- 단순화한 CuTe 구현이 A100에서 production FA-2 대비 88-105% 처리량에 도달
- CuTe는 Triton처럼 하드웨어를 숨기지 않고 레이아웃·MMA atom·copy 전략을 직접 다루게 함
- 성능의 핵심은 warp를 M 방향으로 배치해 online softmax row reduction을 warp 내부에서 끝내는 것

## 인사이트

이 글의 가치는 FlashAttention-2 설명보다 ‘고성능 GPU 커널이 왜 읽기 어려운가’를 실제 코드 경로로 보여준다는 데 있다. Triton으로 80% 성능을 빨리 얻는 길과, CuTe로 하드웨어를 이해하며 마지막 성능을 쥐어짜는 길의 비용 차이가 아주 노골적으로 드러남.
