NVIDIA CUDA Tile은 NVIDIA 텐서 코어의 이식성을 극대화해 GPU 성능을 정점까지 끌어올리는 프로그래밍 모델입니다. 특히 CUDA Tile 위에 개발자만의 독자적인 도메인 특화 언어(DSL)를 구축할 수 있다는 점이 가장 큰 경쟁력입니다.
이번 포스팅에서는 GPU 딥러닝 커널 작성을 위한 오픈 소스 파이썬 DSL, OpenAI Triton의 백엔드로 CUDA Tile을 통합한 최신 개발 성과를 공유합니다. OpenAI Triton은 연산을 작은 블록 단위로 나누는 타일형 연산 기법을 활용하며, PTX를 생성하는 MLIR 기반 컴파일러를 갖추고 있습니다. 덕분에 CUDA 숙련도가 높지 않은 연구자도 효율적인 GPU 코드를 직접 작성할 수 있습니다.
CUDA Tile 및 CUDA Tile IR이란 무엇인가요?
CUDA Tile은 타일 중심 프로그래밍을 기본으로 지원하기 위해 기존 CUDA 모델을 한 단계 확장한 기술입니다. CUDA 13.1부터 도입된 이 모델은 GPU 프로그래밍 패러다임에서 중요한 변화를 가져왔습니다. 이제 개발자들은 SIMT 모델의 개별 스레드 단위 연산에 얽매일 필요없이 타일이라는 더 높은 수준의 추상화 모델을 통해 연산을 직관적으로 표현하고 하드웨어 성능을 최대로 끌어낼 수 있습니다.
개발자가 데이터 블록(타일) 연산을 지정하기만 하면, 컴파일러와 런타임 시스템이 스레드 스케줄링부터 하드웨어 매핑, 리소스 할당까지 자동으로 처리합니다. 이러한 설계 덕분에 프로그래밍 복잡도는 획기적으로 낮아졌으며, 컴파일러는 더욱 강력한 최적화를 수행할 수 있게 되었습니다.
CUDA Tile IR은 MLIR을 기반으로 하는 중간 표현(Intermediate Representation)이자 컴파일러 인프라입니다. NVIDIA GPU에서의 타일 기반 연산을 위해 공식 시맨틱, 오퍼레이션, 타입 시스템을 정의하는 ‘CUDA Tile IR 사양’이 이 기술의 핵심 축을 담당하고 있습니다.
Triton-to-TileIR이란 무엇인가요?
Triton-to-TileIR 백엔드는 Triton이 기존 PTX 대신 CUDA Tile IR을 타겟팅하도록 돕는 가교 역할을 합니다. Triton 컴파일러 생태계를 확장하여, 개발자가 OpenAI Triton으로 작성한 GPU 커널을 새로 도입된 CUDA Tile IR 백엔드에서 그대로 컴파일하고 실행할 수 있게 해줍니다. 이로써 고수준 언어인 Triton과 NVIDIA의 차세대 프로그래밍 모델이 하나로 연결되어, 개발자는 별도의 코드 수정 없이도 최신 하드웨어의 성능을 온전히 활용할 수 있습니다.
GPU 프로그래밍이 기존 SIMT 모델에서 타일 기반 추상화로 진화함에 따라, 이번 통합이 주는 가치는 더욱 분명해졌습니다. 개발자들은 Triton의 편리한 파이썬 문법을 유지하면서도 텐서 코어에 대한 TileIR의 네이티브 지원과 강력한 아키텍처 이식성을 동시에 확보하게 됩니다.
Triton-to-TileIR은 차세대 기술로 향하는 진입 장벽을 획기적으로 낮춥니다. Triton 자체가 데이터를 블록(타일) 단위로 처리하는 언어인 만큼, 개념적으로 CUDA Tile IR과 완벽히 일치하기 때문입니다. 이러한 기술적 일관성 덕분에 백엔드 컴파일 과정 역시 훨씬 직관적으로 변했습니다. 타일 수준의 추상화를 굳이 하위 수준인 스레드 단위(SIMT)로 쪼개어 번역하는 대신, 타일 단위의 의미론을 그대로 유지하며 CUDA Tile IR로 직접 컴파일하는 경로를 구축했습니다.
기존 Triton 사용자들은 새로운 언어를 배우거나 코드를 재작성할 필요가 없습니다. 간단한 환경 변수 설정만으로 컴파일 파이프라인을 PTX에서 CUDA Tile IR 백엔드로 즉시 전환할 수 있으며, 이를 통해 성능 향상은 물론 미래 아키텍처에 대한 호환성까지 선제적으로 확보할 수 있습니다.
나아가 개발자들은 자신의 애플리케이션 내에서 각 커널의 특성에 따라 최적의 백엔드(PTX 또는 CUDA Tile IR)를 자유롭게 선택해 사용할 수 있게 됩니다.
Triton-to-TileIR 개발 로드맵
현재 Triton-to-TileIR은 triton-lang 조직 내에서 인큐베이터 프로젝트로 활발히 개발 중입니다. 이 리포지토리는 메인 Triton 컴파일러에 정식 통합되기에 앞서, CUDA Tile IR 백엔드를 직접 구현하고 정교하게 다듬어 나가는 핵심적인 협업 공간입니다.
개발 로드맵은 다음과 같은 주요 기술 워크스트림으로 구성됩니다.
- 핵심 변환 인프라: Triton 연산을 그에 대응하는 CUDA Tile IR로 매핑하기 위한 MLIR 다이얼렉트(Dialect) 변환 패턴 구현
- 테스트 및 검증: 제어 흐름, 메모리 액세스 패턴, 수치 정밀도의 에지 케이스를 포함하여 변환의 시맨틱 정확성을 확인하기 위한 종합 테스트 스위트 개발
- 성능 벤치마킹: 다양한 연산(행렬 곱셈, 컨볼루션, 요소별 연산, 리덕션 등)에 대해 TileIR 컴파일 커널과 기존 PTX 컴파일 커널의 성능을 비교하는 기준점 수립
- 오픈 소스 프로젝트 통합: Helion과 같은 프로젝트에서 CUDA Tile IR 백엔드를 원활히 지원할 수 있도록 오픈 소스 커뮤니티와 협력 유도
Triton-to-TileIR 사용 방법
Triton-to-TileIR은 현재 소스 코드 기반 컴파일만 지원합니다. 별도의 빌드된 바이너리가 제공되지 않으므로, 로컬 환경에서 직접 프로젝트를 빌드해야 합니다.
사전 요구 사항:
- CUDA 버전: CUDA 13.1 이상
- GPU 아키텍처: NVIDIA Blackwell GPU(예: GeForce RTX 5080). 이전 세대 GPU 아키텍처는 향후 CUDA 릴리스에서 지원될 예정입니다.
소스로부터 구축
사전 요구 사항이 충족되면 리포지토리를 클론하고 빌드를 진행합니다.
# Clone the repository
git clone https://github.com/triton-lang/Triton-to-tile-IR.git
cd Triton-to-tile-IR
# Build and install
# Specific build instructions should be followed according to the project's README
pip install -e .
주의: 상세 구축 단계는 환경에 따라 다를 수 있습니다. 아키텍처별 설정, 의존성 관리 및 트러블슈팅은 Triton-to-TileIR README와 빌드 문서를 확인하시기 바랍니다.
Tile IR 컴파일 확인
빌드 후에는 벡터 덧셈 튜토리얼을 실행하여 Tile IR 백엔드가 정상적으로 작동하는지 확인합니다.
# Navigate to the tutorial directory
cd python/tutorials
# Run the vector addition example with Tile IR enabled
export ENABLE_TILE=1
python 01-vector-add.py
Tile IR 백엔드가 활성화되면, Triton은 SIMT 백엔드에서 사용하는 표준 .cubin 파일 대신 .tileIR 확장자를 가진 파일로 컴파일된 커널을 캐싱합니다. 다음 경로에서 해당 캐시 파일들이 생성되었는지 확인하여 백엔드 적용 여부를 알 수 있습니다:
# Find the Triton cache directory (typically in ~/.triton/cache)
Triton-to-TileIR의 제약 사항
Triton-to-TileIR은 흥미로운 가능성을 열어주지만, 아직 개발 초기 단계인 만큼 미지원 연산이나 일시적인 성능 이슈와 같은 몇 가지 제약이 존재합니다.
미지원 연산
현재 Triton이 지원하는 모든 연산이 Tile IR 백엔드에 구현된 것은 아닙니다. 아직 지원되지 않거나 일부만 지원되는 기능 및 연산에 대해 더 알아보세요.
향후 CUDA의 새 버전이 출시됨에 따라 Triton CUDA Tile IR 백엔드의 호환성도 지속적으로 개선될 예정입니다.
포인터 텐서(Tensor-of-pointer) 사용 시의 성능 저하
Triton의 포인터 텐서 패턴은 텐서가 메모리 액세스 패턴을 기술하는 포인터 집합으로 이루어집니다. 이 방식은 CUDA 13.1 기반의 Tile IR 백엔드에서 다소 아쉬운 성능을 보이고 있습니다. 이는 일시적인 현상으로, 해당 패턴을 사용하는 워크로드에 대해서는 다음과 같은 대응 방안을 제안합니다.
- 특정 핵심 연산에 대해 일시적으로 SIMT 백엔드 사용
- 향후 릴리스될 최적화 패스 적용 대기
- TMA load/store API를 사용하도록 코드 수정
앞서 언급한 TMA API를 활용해야 하는 이유는 명확합니다. 커널에서 로드하는 텐서 대부분은 이미 연속적인 타일 구조와 정교한 형태, 스트라이드를 갖추고 있습니다. 따라서 커널 내부에서 별도로 포인터 텐서를 생성할 필요가 없습니다. 대신 이러한 레이아웃 정보를 TMA load/store API에 직접 전달하면, Tile IR 백엔드가 하드웨어 역량을 십분 활용하여 성능을 비약적으로 향상시킬 수 있습니다.
일반적인 포인터 텐서 패턴의 예시는 다음과 같습니다.
# Before: tensor-of-pointer style
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_am
+ offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk
+ offs_n[None, :] * stride_bn)
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
여기서 a_ptrs의 모든 요소는 커널 내부에서 직접 계산한 명시적 포인터들입니다. 타일 자체가 연속적인 구조를 띠고 있어 shape, strides, block_shape만으로도 그 레이아웃을 충분히 정의할 수 있다는 점을 고려하면, 이는 상당히 번거롭고 비효율적인 방식입니다.
TMA를 사용하면 동일한 연산을 다음과 같이 재작성할 수 있습니다.
desc_a = tl.make_tensor_descriptor(
a, # base pointer
shape=(M, K),
strides=(stride_am, stride_ak),
block_shape=(BLOCK_M, BLOCK_K) # tile size
)
desc_b = tl.make_tensor_descriptor(
b, shape=(K, N),
strides=(stride_bk, stride_bn),
block_shape=(BLOCK_K, BLOCK_N)
)
offs_m = pid_m * BLOCK_M
offs_n = pid_n * BLOCK_N
a_tile = desc_a.load([offs_m, 0]) # [BLOCK_M, BLOCK_K]
b_tile = desc_b.load([0, offs_n]) # [BLOCK_K, BLOCK_N]
desc_c.store([offs_m, offs_n], acc) # TMA-backed store
더 알아보기
Triton-to-TileIR 프로젝트는 개발자의 생산성과 하드웨어 효율성 사이의 간극을 메우는 GPU 프로그래밍 진화의 중요한 성과입니다. 사용하기 쉬운 Triton의 타일 중심 모델이 CUDA Tile IR 가상 명령어 세트(vISA)를 지원함에 따라, 머신러닝 전문가와 GPU 개발자 모두 성능과 이식성, 미래 경쟁력을 동시에 확보하게 되었습니다.
기존 Triton 사용자들에게 TileIR 백엔드는 코드 수정을 최소화하면서도 차세대 GPU 아키텍처의 잠재력을 극대화할 수 있는 최적의 방법을 제시합니다. 또한, 이번 협업은 언어 설계자와 하드웨어 제조사 간의 전략적 파트너십이 어떤 시너지를 창출할 수 있는지 보여주는 모범 사례이기도 합니다. 결과적으로 빠른 혁신을 뒷받침하는 고수준 추상화 모델을 유지하면서, 최첨단 하드웨어가 제공하는 강력한 성능을 별도의 진입 장벽 없이 온전히 활용할 수 있는 길이 열린 것입니다.
이 프로젝트가 인큐베이팅 단계를 넘어 실제 서비스 환경에 적용 가능한 수준으로 성숙해지면, Triton의 대중화는 물론 타일 기반 GPU 프로그래밍의 패러다임을 바꾸는 중요한 전환점이 될 것입니다. NVIDIA가 지향하는 성공의 기준은 명확합니다. GPU 전문 지식이 깊지 않은 연구자라도, NVIDIA GPU에서 최적의 성능을 발휘하는 Triton 코드를 손쉽게 작성할 수 있는 환경을 구축하는 것입니다.
더 자세한 내용은 triton-lang/Triton-to-tile-IR GitHub 리포지토리와 CUDA Tile IR 백엔드 성능 튜닝 팁에서 확인하실 수 있습니다.