---
title: "Pyrefly가 PyTorch 텐서 shape를 정적 타입으로 추적한다"
published: 2026-05-11T21:58:14.000Z
canonical: https://jeff.news/article/2600
---
# Pyrefly가 PyTorch 텐서 shape를 정적 타입으로 추적한다

Pyrefly가 실험 기능으로 PyTorch 모델의 텐서 shape를 정적 타입 시스템에서 추적하는 기능을 공개했다. NanoGPT 같은 모델에서도 중간 텐서 shape를 에디터 인레이 힌트로 보여주고, shape 변환 오류를 실행 전에 잡는 것이 목표다.

## PyTorch 코드에서 제일 귀찮은 버그를 정적으로 잡겠다는 이야기

- Pyrefly가 PyTorch 텐서 shape 추적 기능을 실험적으로 공개함
  - 목표는 모델을 실행하지 않아도 각 중간 텐서의 shape를 에디터에서 바로 보는 것임
  - 예시로 NanoGPT의 forward 메서드를 보여주는데, 기존에는 모든 변수가 그냥 Tensor로만 보이던 것이 `Tensor[B, T, NEmbedding]`, `Tensor[T]`처럼 구체적으로 표시됨
  - 로컬 변수마다 직접 annotation을 달지 않아도, 함수와 클래스 경계의 일부 annotation에서 shape를 추론하는 방향임

- 이 기능이 필요한 이유는 딥러닝 코드에서 shape 버그가 너무 잘 숨어 있기 때문임
  - 어떤 연산자는 shape가 틀리면 바로 터지지만, 어떤 경우는 브로드캐스팅 때문에 조용히 잘못된 결과를 만들 수 있음
  - `print(x.shape)`를 여기저기 박거나 디버거로 따라가는 방식은 모델이 커질수록 피곤해짐
  - Pyrefly는 “실행 전에 shape 흐름을 계속 보여주자”는 쪽에 베팅함

> [!IMPORTANT]
> Pyrefly의 방향은 강력한 constraint solver로 모든 shape 오류를 증명하는 게 아니라, 적은 annotation으로 최대한 많은 shape 힌트를 추론해 개발 중 피드백을 주는 쪽임.

## 핵심 구현 아이디어는 두 가지

- 첫 번째는 타입 시스템 안에서 심볼릭 정수 연산을 지원하는 것임
  - `Tensor[3, 4]`는 shape가 `(3, 4)`인 2D 텐서를 뜻함
  - `nn.Linear[3, 4]`는 `Tensor[..., 3]`을 받아 `Tensor[..., 4]`를 반환하는 모듈로 표현됨
  - `Dim[X]`는 런타임 정수 값을 타입 수준으로 연결하는 다리 역할을 함

- 타입 수준 산술도 가능함
  - `a: Dim[3]`, `b: Dim[4]`라면 `a * b`는 `Dim[12]`로 추론됨
  - 예시에서는 `custom_rand_tensor[A, B](a: Dim[A], b: Dim[B]) -> Tensor[(A + B) // 2]` 같은 식도 가능함
  - 즉 shape가 고정 숫자가 아니라 `N`, `M`, `A + B`, `D // NHead` 같은 식으로 움직일 수 있음

- shape 다형적인 모듈도 표현할 수 있음
  - 예를 들어 `Linear[N, M]`은 마지막 차원이 `N`인 입력을 받아 마지막 차원을 `M`으로 바꾸는 모듈로 쓸 수 있음
  - `forward[*Xs](inp: Tensor[*Xs, N]) -> Tensor[*Xs, M]`처럼 앞쪽 batch 차원들은 그대로 보존하고 마지막 차원만 바꾸는 식임
  - 이게 실제 모델 코드에서 중요한 이유는 batch, sequence length, embedding size가 여러 레이어를 지나며 계속 이어지기 때문임

- 두 번째는 PyTorch 연산자별 shape 변환 명세임
  - `torch.mm`처럼 단순한 연산은 `Tensor[M, K]`와 `Tensor[K, N]`을 받아 `Tensor[M, N]`을 반환한다고 타입 stub으로 표현 가능함
  - `reshape`, `cat`, `F.interpolate`처럼 로직이 복잡한 연산은 작은 DSL로 shape 변환 규칙을 적음
  - 새 PyTorch 연산자의 shape 지원을 추가할 때 Pyrefly 내부를 직접 고치지 않고, 라이브러리 명세를 확장하는 구조를 노림

```mermaid
sequenceDiagram
    participant 개발자
    participant 에디터
    participant Pyrefly
    participant 타입시스템
    participant PyTorch명세
    개발자->>에디터: PyTorch 모델 코드 작성
    에디터->>Pyrefly: 타입 분석 요청
    Pyrefly->>타입시스템: Tensor 차원과 Dim 산술 추론
    Pyrefly->>PyTorch명세: 연산자별 shape 변환 규칙 조회
    Pyrefly-->>에디터: 중간 Tensor shape 인레이 힌트 표시
    Pyrefly-->>개발자: shape 불일치 오류 표시
```

## 기존 접근과의 차이

- Pyre와 Pyright 쪽 기존 시도는 문법이 무거웠다고 설명함
  - 예전 Pyre 접근은 `Literal`과 `IntDiv` 같은 타입 생성자를 써서 shape 산술을 표현했음
  - 예시로 `Tensor[float, M, Literal[2], IntDiv[M, 2]]` 같은 형태가 나오는데, 실제 모델 코드에 계속 쓰기엔 부담이 큼
  - Pyrefly는 사용자에게 노출되는 타입 기능은 단순하게 두고, 복잡한 연산자 shape 로직은 DSL 쪽으로 밀어낸다고 설명함

- jaxtyping과도 포지션이 다름
  - jaxtyping은 JAX, NumPy, PyTorch 같은 array-like 컨테이너를 폭넓게 지원하고, typeguard나 beartype과 함께 런타임에서 검사함
  - 문법은 `Shaped[Tensor, "M 2 M//2"]`처럼 문자열 기반이라 범용적이지만 다소 장황함
  - Pyrefly는 jaxtyping annotation도 대체 front-end로 받아 내부적으로 generic 기반 표현으로 변환할 수 있다고 함

- Pyrefly가 jaxtyping 대비 강조하는 한계는 “클래스 전체에서 symbolic dimension 공유”임
  - jaxtyping만으로는 클래스 안 여러 변수와 함수, 모듈 계층 사이에서 같은 symbolic dimension을 일관되게 공유하기 어렵다고 설명함
  - 그래서 NanoGPT처럼 여러 모듈이 연결된 실제 모델 전체를 end-to-end로 타입체크하기는 Pyrefly 방식이 더 적합하다는 주장임

## 아직은 실험 기능

- 문서 첫 줄부터 API와 동작이 바뀔 수 있다고 못 박음
  - 즉 지금 당장 프로덕션 타입 정책으로 강제하기보다는, 모델 개발 중 shape 힌트와 오류 탐지 보조 도구로 보는 게 맞음
  - PyTorch 연산자 커버리지는 계속 확장해야 하고, 특히 복잡한 operator의 shape DSL 명세가 많이 필요함
  - Pyrefly 팀도 fixture stub과 PyTorch operator DSL specification 기여를 요청하고 있음

---
## 기술 맥락

- Pyrefly의 선택은 텐서 shape를 일반 타입 annotation의 확장으로 다루는 거예요. 딥러닝 모델은 값의 타입보다 차원의 흐름이 더 자주 문제를 만들기 때문에, `Tensor[B, T, C]` 같은 정보를 정적 분석기가 이해하게 만드는 게 핵심이에요.

- 여기서 중요한 건 constraint solver를 크게 키우지 않았다는 점이에요. 모든 가능한 shape 관계를 엄밀하게 증명하려고 하면 annotation이 늘고 오류 메시지도 복잡해지거든요. Pyrefly는 대신 로컬 변수 shape를 많이 추론해서 에디터 힌트로 보여주는 실용적인 쪽을 택했어요.

- 복잡한 PyTorch 연산은 별도 DSL로 빼는 것도 유지보수 관점에서 의미가 있어요. `mm`처럼 간단한 건 타입 stub으로 충분하지만, `reshape`나 `cat`은 규칙이 훨씬 복잡해요. 이걸 타입 시스템 본체에 다 넣으면 도구 자체가 무거워지니, 라이브러리 명세로 확장 가능하게 만든 거예요.

- jaxtyping과 비교되는 지점은 검사 시점과 범위예요. jaxtyping은 런타임 검사와 범용 array 생태계에 강하고, Pyrefly는 정적 분석과 PyTorch 모델 계층 전체의 shape 흐름에 더 집중해요. 모델을 실행하기 전에 에디터에서 shape를 보고 싶은 팀이라면 Pyrefly 쪽 접근이 꽤 매력적일 수 있어요.

## 핵심 포인트

- Pyrefly는 PyTorch 코드의 텐서 shape를 실행 없이 추론해 에디터에 표시함
- 핵심은 타입 시스템 안의 심볼릭 정수 연산과 PyTorch 연산자별 shape 변환 명세
- Tensor[B, C, H, W], Dim[N], Tensor[*Xs, N] 같은 문법으로 shape 다형성을 표현함
- jaxtyping보다 클래스와 모듈 계층 전체에서 symbolic dimension을 공유하기 쉽다는 점을 강조함

## 인사이트

딥러닝 코드에서 shape 버그는 디버깅 시간이 정말 잘 녹는 영역이다. Pyrefly의 접근은 “더 많은 annotation으로 더 많은 빨간 줄”보다 “적은 annotation으로 중간 shape를 최대한 보여주자”에 가깝다는 점이 실용적으로 보인다.
