본문으로 건너뛰기
피드

FlashAttention-2를 CuTe로 밑바닥부터 구현하며 겪은 GPU 커널 지옥기

ai-ml 약 16분
vote
0
댓글
북마크

글쓴이는 FlashAttention-2를 Triton으로는 이틀 만에 구현했지만, 같은 알고리즘을 NVIDIA CuTe 기반 C++ 커널로 옮기는 데는 몇 주와 약 100시간의 해설 작업이 걸렸다고 말한다. 이 글은 A100에서 GMEM·SMEM 비동기 복사, tiled MMA, swizzling, online softmax, epilogue store까지 생산 수준 커널의 핵심을 한 줄씩 파고드는 튜토리얼에 가깝다.

  • 1

    단순화한 CuTe 구현이 A100에서 production FA-2 대비 88-105% 처리량에 도달

  • 2

    CuTe는 Triton처럼 하드웨어를 숨기지 않고 레이아웃·MMA atom·copy 전략을 직접 다루게 함

  • 3

    성능의 핵심은 warp를 M 방향으로 배치해 online softmax row reduction을 warp 내부에서 끝내는 것

Triton 이틀, CuTe 몇 주

  • 글쓴이는 FlashAttention-2를 Triton으로는 이틀 만에 구현했다고 함

    • GPU kernel을 처음 만져본 상태였는데도, Triton의 tl.dot 같은 고수준 추상화 덕분에 꽤 빠르게 결과가 나옴
    • 그런데 같은 걸 C++ CuTe로 옮기자 상황이 완전히 달라짐
    • 몇 주 동안 CUDA layout, shared memory, swizzling, MMA fragment, async copy를 하나씩 파고들어야 했고, 블로그 작성까지 포함해 약 100시간을 썼다고 밝힘
  • 이 글은 FlashAttention-2 논문 요약이 아니라 “생산급 GPU 커널을 실제로 어떻게 짜는가”에 가까움

    • 대상 GPU는 A100, 즉 Ampere 세대임
    • fp16 입력, fp32 accumulation, causal masking·RoPE·dropout·KV-cache 같은 분기는 제거한 단순 attention을 다룸
    • 그래도 핵심 경로는 Tri Dao의 production FA-2와 같은 idiom, 같은 building block을 따라감
  • 성능도 장난감 수준으로 끝나지 않음

    • A100에서 head dimension 64·128, sequence length 최대 64K 조건을 다룸
    • 단순화한 커널이 production FA-2 대비 88-105% throughput을 낸다고 밝힘
    • peak 기준 fp16 tensor core utilization은 63%까지 나옴
    • 목적은 새 알고리즘이 아니라, 단순화해도 핵심 성능 구조가 유지된다는 걸 보여주는 쪽임

중요

> 이 글의 제일 센 숫자는 “production FA-2 대비 88-105%”임. 학습용으로 분기를 걷어낸 커널인데도 성능이 비슷하게 나온다는 건, 핵심 병목과 레이아웃 결정을 제대로 잡았다는 뜻임.

CuTe가 왜 빡센가

  • CuTe는 CUTLASS 3.x 안의 layout algebra 코어임

    • CUTLASS는 NVIDIA의 GEMM과 GEMM 인접 커널용 오픈소스 빌딩 블록 라이브러리임
    • CuTe는 shape, stride, layout, tile, thread mapping을 템플릿으로 조합하게 해줌
    • PyTorch tensor처럼 보이는 객체를 쓰지만, 실제로는 훨씬 낮은 레벨에서 메모리와 레지스터 배치를 직접 다룸
  • Triton과 CuTe의 철학 차이가 이 글의 큰 축임

    • Triton은 tl.dot(q, k)를 쓰면 MMA atom, shared memory swizzle, tiling을 상당 부분 알아서 잡아줌
    • CuTe는 어떤 MMA atom을 쓸지, 어떤 copy atom을 쓸지, SMEM 레이아웃을 어떻게 꼬아 bank conflict를 피할지 직접 결정하게 함
    • 글쓴이 표현대로라면 Triton은 하드웨어를 숨기고, CuTe는 하드웨어를 정면으로 보게 만듦
  • 그래서 CuTe를 잘 쓰려면 GPU 구조를 알아야 함

    • thread, warp, thread block, SM, tensor core, shared memory bank 같은 개념이 전부 코드에 드러남
    • register fragment는 tensor처럼 보이지만 실제 주소 지정 가능한 메모리가 아니라 컴파일러가 물리 레지스터에 매핑한 추상 구조임
    • 이 차이를 모르면 layout을 보고도 “왜 이게 맞지?”에서 계속 막힘

커널 구조는 단순한데, 디테일이 지옥임

  • Attention 자체는 익숙한 그 식임

    • P = QKᵀ / sqrt(d_h)
    • S = softmax(P)
    • O = S V
    • 문제는 이걸 A100 tensor core와 shared memory pipeline에 맞춰 어떻게 쪼개느냐임
  • 이 구현의 grid는 batch/head와 Q tile을 기준으로 잡음

    • 각 thread block은 하나의 Q tile을 맡고, Q tile은 block이 끝날 때까지 변하지 않음
    • K와 V tile을 반복해서 가져오며 해당 Q tile의 output tile을 완성함
    • Q, K, V는 global memory에서 shared memory로 staging한 뒤, 다시 register fragment로 올려 MMA에 넣음
  • 핵심 loop는 대략 이런 흐름임

    • Q tile을 한 번 global memory에서 shared memory로 가져옴
    • 0번째 K tile을 미리 가져옴
    • K가 도착하면 QKᵀ GEMM을 날림
    • 그 사이 V tile을 prefetch함
    • softmax 통계를 갱신하고 SV GEMM으로 output accumulator를 업데이트함
    • 다음 K tile prefetch와 현재 V 작업을 겹치며 sequence 전체를 돈 뒤, 마지막 output을 global memory에 저장함
sequenceDiagram
    participant 전역메모리 as 전역 메모리
    participant 공유메모리 as 공유 메모리
    participant 레지스터 as 레지스터 조각
    participant 텐서코어 as 텐서 코어
    participant 소프트맥스 as 온라인 소프트맥스
    전역메모리->>공유메모리: Q 타일과 첫 K 타일 비동기 복사
    공유메모리->>레지스터: LDSM으로 Q/K 조각 로드
    레지스터->>텐서코어: QKᵀ MMA 실행
    전역메모리->>공유메모리: V 타일 prefetch
    텐서코어->>소프트맥스: attention score 전달
    소프트맥스->>레지스터: max/sum 갱신과 누적값 rescale
    레지스터->>텐서코어: S·V MMA 실행
    레지스터->>전역메모리: 최종 O 타일 저장

성능을 가르는 결정들

  • A100에서는 GMEM에서 SMEM으로 cp.async 비동기 복사를 쓸 수 있음

    • HBM에서 shared memory로 데이터를 가져오는 동안 tensor core 계산을 진행해 지연을 숨기는 구조임
    • Q는 block 동안 고정이라 한 번만 가져오고, K와 V는 loop마다 prefetch함
    • FA2는 여러 단계 앞까지 깊게 prefetch하지 않고, 한 block 앞 정도만 가져가는 단순한 전략을 쓴다고 설명함
  • vectorized load와 coalesced load를 구분해서 설명하는 부분도 실전적임

    • vectorized load는 한 thread가 한 번에 128-bit, 즉 fp16 8개를 가져오는 쪽임
    • coalesced load는 warp 단위로 128-byte 연속 블록을 한 transaction으로 가져오는 쪽임
    • 둘 다 데이터가 연속적이어야 제대로 이득을 봄
    • fp16 하나씩 16-bit로 읽으면 128-bit load 대비 이론상 8배 손해가 날 수 있다고 짚음
  • shared memory는 빠르지만 bank conflict가 바로 발목을 잡음

    • NVIDIA shared memory는 32개 bank로 나뉘고, 같은 warp의 여러 thread가 같은 bank의 서로 다른 주소를 동시에 읽으면 직렬화됨
    • 32x32 float 배열을 column 방향으로 읽으면 모든 thread가 같은 bank를 때리는 32-way conflict가 생길 수 있음
    • padding으로 해결할 수도 있지만, FA2에서는 메모리 footprint와 접근 패턴 때문에 swizzling이 더 맞는 선택임
  • swizzling은 shared memory 주소를 XOR 기반으로 재배치해 bank conflict를 피하는 기법으로 설명됨

    • 논리적으로는 같은 (row, col)을 읽고 쓰지만, 물리 주소는 bank가 골고루 퍼지도록 꼬아둠
    • FA2에서는 fp16 8개짜리 128-bit chunk는 그대로 연속이어야 하므로, 하위 3비트는 보존하고 그 위 bit들을 XOR로 섞는 식의 Swizzle을 씀
    • 대표적으로 Swizzle<3, 3, 3> 패턴을 다룸

ℹ️참고

> CuTe layout은 “데이터가 이렇게 생겼다”가 아니라 “이 주소와 이 thread mapping으로 데이터를 해석하겠다”에 더 가까움. 그래서 row-major, column-major, colex indexing, swizzle이 한꺼번에 나오면 머리가 터지는 게 정상임.

Online softmax가 거의 공짜가 되는 이유

  • 이 커널의 결정적인 layout 선택은 warp를 M 방향으로 타일링하는 것임

    • 각 warp가 QKᵀ output의 한 row 전체를 소유하도록 배치함
    • softmax의 row max와 row sum reduction이 warp 내부에서 끝남
    • __shfl_xor_sync 같은 warp shuffle primitive로 shared memory staging이나 __syncthreads() 없이 처리할 수 있음
  • 만약 warp를 N 방향으로 나눴다면 softmax가 훨씬 비싸졌을 가능성이 큼

    • 한 row의 값이 여러 warp에 쪼개지면 reduction이 warp 경계를 넘어가야 함
    • 그러면 shared memory에 중간값을 넣고 thread block sync를 해야 함
    • FlashAttention-1에서 더 부담스러웠던 부분을 FA2는 이 배치로 줄인 셈임
  • online softmax는 tile을 돌면서 max와 exp sum을 갱신함

    • 새 score tile의 row max를 구함
    • 기존 max와 비교해 새 max를 만들고, 이전 accumulator를 보정 계수로 rescale함
    • exp sum도 누적 갱신함
    • 전체 attention matrix를 저장하지 않아도 softmax 결과와 output accumulator를 일관되게 유지할 수 있음

V copy와 이상한 no-op 이야기

  • V는 Q, K보다 복사가 더 까다로움

    • S @ V에서는 concat dimension이 K/V sequence 축이라, V를 MMA가 기대하는 모양으로 사실상 transpose해서 읽어야 함
    • GMEM에서 SMEM으로는 row-major 그대로 빠르게 복사하고, SMEM에서 register로 올릴 때 LDSM_T 계열 transposed load를 사용함
  • 글에서 제일 흥미로운 디버깅 포인트는 sVtNoSwizzle

    • production FA2 쪽 코드에는 V fragment shape을 얻기 위해 swizzle을 제거한 듯한 layout이 등장함
    • 처음 보면 “이거 안 하면 뭔가 깨지나?” 싶지만, 글쓴이가 테스트해보니 실제로는 없어도 결과가 같았다고 함
    • hdim 32·96 같은 케이스에서 fragment shape 출력이 이상하게 보이는 이유는 swizzle layout과 partition_fragment의 layout 추론 방식 때문인데, 실제 copy와 MMA가 일관되게 같은 mapping을 쓰면 계산은 맞게 돌아감
  • 여기서 글쓴이는 더 단순한 선택지도 제안함

    • hdim에 따라 swizzle pattern을 바꾸는 대신, 관련 hdim에 Swizzle<3,3,3>를 일관되게 쓰면 shape inconsistency가 줄어든다고 봄
    • 실제 테스트에서도 문제가 없었다고 설명함
    • 이런 부분이 “코드를 따라 쓰는 것”과 “왜 되는지 이해하는 것”의 차이를 보여줌

그래서 누가 읽어야 하나

  • 이 글은 FlashAttention-2를 쓰는 사람보다 만드는 사람에게 더 맞음

    • PyTorch나 JAX에서 attention 성능만 챙기는 독자라면 과함
    • CUDA kernel, Triton 다음 단계, CUTLASS/CuTe, tensor core 최적화를 공부하는 사람에게는 꽤 귀한 자료임
    • 특히 문서가 reference에 머물러 있고 튜토리얼이 부족한 CuTe 생태계에서 실제 커널 하나를 끝까지 따라가는 글이라는 점이 큼
  • 한국 개발자 입장에서도 가치가 높음

    • LLM serving, training 최적화, custom CUDA extension을 다루는 팀이라면 attention kernel의 병목을 이해하는 데 직접 도움이 됨
    • Triton으로 빠르게 실험할지, CuTe/CUDA로 하드웨어를 더 깊게 제어할지 판단하는 기준도 줌
    • “고성능 커널은 왜 이렇게 읽기 어려운가”에 대한 현실적인 답이 들어 있음

기술 맥락

  • 이 글에서 선택한 기술적 방향은 A100에 맞춘 FlashAttention-2 forward kernel을 C++ CuTe로 직접 구현하는 거예요. 왜 굳이 어려운 CuTe를 택했냐면, Triton은 빠르게 성능을 내기 좋지만 shared memory layout, MMA atom, register fragment 배치를 세밀하게 이해하기에는 너무 많은 걸 숨겨주거든요.

  • 핵심 트레이드오프는 개발 시간과 하드웨어 제어권이에요. Triton은 글쓴이가 이틀 만에 구현할 정도로 생산성이 높지만, CuTe는 몇 주가 걸렸고 layout algebra 때문에 계속 막혔어요. 대신 CuTe를 통과하면 Ampere의 cp.async, LDSM, swizzling, warp reduction이 실제 성능에 어떻게 연결되는지 몸으로 알게 돼요.

  • warp를 M 방향으로 배치한 결정은 online softmax 때문이에요. softmax는 row 단위 max와 sum이 필요한데, row가 warp 안에 온전히 들어오면 __shfl_xor_sync로 reduction을 끝낼 수 있어요. 반대로 row가 여러 warp에 나뉘면 shared memory와 block sync가 필요해지고, attention kernel에서 바로 병목이 돼요.

  • swizzling을 넣은 이유는 shared memory가 무조건 빠른 공간이 아니기 때문이에요. 32개 bank에 접근이 몰리면 load가 직렬화되니까, 논리 좌표는 유지하되 물리 주소를 XOR로 꼬아서 bank를 분산해야 해요. FA2에서는 128-bit fp16 chunk의 연속성도 지켜야 해서 swizzle bit 선택이 성능과 정확성 모두에 영향을 줘요.

  • 결과적으로 이 글은 “FlashAttention-2가 빠르다”보다 “왜 빠른 구현은 이런 모양이 되는가”를 설명해요. A100, fp16, fp32 accumulation, head_dim 64·128, sequence length 최대 64K라는 조건 안에서 메모리 이동과 tensor core 계산을 어떻게 겹칠지 하나씩 결정한 기록에 가깝거든요.

이 글의 가치는 FlashAttention-2 설명보다 ‘고성능 GPU 커널이 왜 읽기 어려운가’를 실제 코드 경로로 보여준다는 데 있다. Triton으로 80% 성능을 빨리 얻는 길과, CuTe로 하드웨어를 이해하며 마지막 성능을 쥐어짜는 길의 비용 차이가 아주 노골적으로 드러남.

댓글

댓글

댓글을 불러오는 중...

ai-ml

유튜브, AI 생성 영상에 자동 라벨 붙인다

유튜브가 사실적으로 보이거나 의미 있게 AI로 변경·생성된 콘텐츠에 더 눈에 띄는 라벨을 적용하고, 제작자가 AI 사용 여부를 밝히지 않아도 내부 신호로 감지되면 자동 라벨을 붙이겠다고 밝혔다. 다만 라벨만으로 추천 노출이나 수익화 자격이 바뀌지는 않으며, 제작자는 YouTube Studio에서 잘못된 판정을 수정할 수 있다.

ai-ml

테크 CEO들의 'AI 만능론', 숫자는 아직 그렇게 말하지 않는다

테크 업계에서 AI를 이유로 한 대규모 감원과 조직 재편이 이어지는 가운데, Box 창업자 애런 레비는 CEO들이 실제 업무의 마지막 1마일을 모른 채 AI 에이전트의 능력을 과대평가하고 있다고 지적했다. 2026년 첫 5개월 동안 이미 11만5430명이 해고됐고, 여러 연구는 AI 도입이 체감 생산성만큼 실제 생산성을 끌어올렸다는 근거가 아직 약하다고 말한다.

ai-ml

오픈AI와 앤트로픽, 코딩 에이전트로 드디어 돈 되는 시장을 찾은 듯

사이먼 윌리슨은 오픈AI와 앤트로픽이 코딩 에이전트와 기업용 과금으로 진짜 제품-시장 적합성을 찾았다고 봐. 개인 구독자에게는 월 100달러 플랜이 싸게 느껴지지만, 기업 고객은 이제 사용량 기준 토큰 가격을 그대로 내기 시작했고 이게 대형 고객 예산을 빠르게 흔들고 있다는 얘기야.

ai-ml

컴팔과 GMI 클라우드, 대규모 추론용 AI 인프라 구축 협력

컴팔이 실리콘밸리 기반 AI 인프라 기업 GMI 클라우드와 협력해 대규모 추론과 에이전틱 AI 워크로드에 맞춘 GPU 서버 인프라를 구축한다고 발표했어. COMPUTEX 2026에서는 NVIDIA HGX B300을 지원하는 Compal SGX30-2 같은 고성능 AI 서버 플랫폼도 선보일 예정이야.

ai-ml

AI 쓰면 편해진다더니, 직장인들은 ‘AI 과부하’에 지쳐가는 중

국내 직장인들이 AI 전환 압박, AI 답변 검증 부담, 대체 불안 때문에 피로감을 호소하고 있어. 중앙일보 설문에서는 5284명 중 31.6%가 ‘AI 답변 검증에 시간이 더 걸릴 때’를 가장 지치는 순간으로 꼽았고, 기업들은 무작정 AI 사용량을 밀어붙이는 방식에서 업무 방식 재설계로 넘어가야 한다는 지적이 나와.