
Achieve CUTLASS C++ Performance with Python APIs Using CuTe DSL

CuTe, a core component of CUTLASS 3.x, provides a unified algebra for describing data layouts and thread mappings, and abstracts complex memory access patterns into composable mathematical operations. 
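As a rough illustration of what "layout algebra" means, the following pure-Python sketch models a layout as a shape and stride pair. The pair maps logical coordinates, or a 1-D index, to linear memory offsets. This is not the CuTe API: the `Layout` class and its methods are hypothetical. Only the core idea follows CuTe's convention: the offset is the sum of coordinate times stride, and the leftmost mode varies fastest when a 1-D index is converted to a coordinate.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple such as ((2, 2), 4) into (2, 2, 4)."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for sub in t for x in flatten(sub))


class Layout:
    """Hypothetical model of a layout: a (shape, stride) pair, not the CuTe API."""

    def __init__(self, shape, stride):
        self.shape = flatten(shape)
        self.stride = flatten(stride)
        assert len(self.shape) == len(self.stride), "shape and stride ranks differ"

    def size(self):
        return prod(self.shape)

    def idx2crd(self, idx):
        """1-D index -> coordinate, with the leftmost mode varying fastest."""
        crd = []
        for s in self.shape:
            crd.append(idx % s)
            idx //= s
        return tuple(crd)

    def __call__(self, *coord):
        """Map a flat coordinate, or a single 1-D index, to a linear offset."""
        if len(coord) == 1:
            coord = self.idx2crd(coord[0])
        return sum(c * d for c, d in zip(coord, self.stride))


col_major = Layout((4, 8), (1, 4))        # 4x8 matrix, column-major
row_major = Layout((4, 8), (8, 1))        # 4x8 matrix, row-major
print(col_major(1, 2), row_major(1, 2))   # 9 10
print([row_major(i) for i in range(6)])   # [0, 8, 16, 24, 1, 9]
nested = Layout(((2, 2), 4), ((1, 8), 2)) # hierarchical shape and stride
print(nested.size())                      # 16
```

Only the stride differs between the two matrices, so the code that indexes the data does not change when the storage order does.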

While CUTLASS 3.x and CuTe have empowered kernel developers to achieve peak performance on Tensor Cores through intuitive abstractions, their extensive use of C++ templates results in long compilation times. These compilation costs, together with the growing adoption of Python and just-in-time (JIT) compilation in both research and production generative AI workflows, have driven the development of CUTLASS 4.

This post explains the advantages of using CuTe DSL. We show that it offers a programming model consistent with CuTe C++, Tensor Core efficiency similar to C++ across GPU generations, and much shorter compilation times than C++.

For more information about the fundamentals of CuTe and CUTLASS 3.x, see CUTLASS: Principled Abstractions for Handling Multidimensional Data Through Tensors and Spatial Microkernels and CUTLASS 3.x: Orthogonal, Reusable, and Composable Abstractions for GEMM Kernel Design.

CuTe DSL: The foundation of CUTLASS 4

The new CuTe DSL (in Beta) in CUTLASS 4 brings the power of CuTe to Python programmers, allowing low-level GPU kernel authoring without the hassle of C++ template metaprogramming.

To flatten the learning curve, CuTe DSL relies on the same fundamental concepts that underpin CuTe. Visit NVIDIA/cutlass on GitHub to see several CuTe DSL examples, including the persistent variant of dense GEMM, grouped GEMM, and Fused Multi-Head Attention (FMHA).

Comparing CuTe DSL and CuTe C++

CuTe offers a consistent GPU programming model across more than a decade of NVIDIA GPU architectures through its robust layout representation and algebra. CuTe DSL retains the exact same programming model users have come to expect from CuTe C++, but with the ease of Python. With this come blazing-fast compile times, substantially improved error messages, a flatter learning curve, and near-instant integration into Python-native DL frameworks.

A side-by-side comparison of C++ and DSL code shows that the two share identical programming models and programming patterns. The only differences lie in C++ versus Python syntax.

TiledMMA

cute::TiledMma is a spatial microkernel that describes the tiling and permutations of any hardware MMA atom across a set of "threads" and data. Its representation enables writing canonical triple-nested for loops for any hardware MMA, be it SIMT FP64 or the cutting-edge NVFP4 Blackwell Tensor Core instructions.
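The canonical triple loop described above can be sketched in plain Python. Here a hypothetical `mma_atom` function stands in for a single hardware MMA instruction, and a hypothetical `tiled_gemm` loops over M tiles, N tiles, and K blocks, calling the atom once per step. The sketch only shows the loop structure; it does not model how a real TiledMMA distributes work across threads.

```python
def mma_atom(A, B, C, m0, n0, k0, tm, tn, tk):
    """Stand-in for one MMA instruction: C[m][n] += A[m][k] * B[k][n] on one tile."""
    for m in range(m0, m0 + tm):
        for n in range(n0, n0 + tn):
            acc = C[m][n]
            for k in range(k0, k0 + tk):
                acc += A[m][k] * B[k][n]
            C[m][n] = acc


def tiled_gemm(A, B, M, N, K, tile=(2, 2, 2)):
    """Canonical triple loop over (M, N, K) tiles, one atom call per step."""
    tm, tn, tk = tile
    assert M % tm == 0 and N % tn == 0 and K % tk == 0, "sketch assumes exact tiling"
    C = [[0] * N for _ in range(M)]
    for m0 in range(0, M, tm):          # loop over M tiles
        for n0 in range(0, N, tn):      # loop over N tiles
            for k0 in range(0, K, tk):  # loop over K blocks (the contraction mode)
                mma_atom(A, B, C, m0, n0, k0, tm, tn, tk)
    return C
```

Changing the atom shape through the `tile` argument leaves the loop nest untouched, which is the property the TiledMMA abstraction aims to preserve across very different MMA instructions.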

// CuTe C++
auto tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, TC,
                                                     128, 128,
                                                     UMMA::Major::MN, UMMA::Major::MN>{},
                                Layout<Shape<_1,_1>>{});

// A and B "fragments": these are actually UMMA shared memory descriptors
Tensor tCrA = tiled_mma.make_fragment_A(sA);    // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = tiled_mma.make_fragment_B(sB);    // (MMA,MMA_N,MMA_K,PIPE)

// Accumulator tensor in TMEM
Tensor tCtC = tiled_mma.make_fragment_C(tCgC);  // (MMA,MMA_M,MMA_N)

CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB), "A and B contraction modes do not match!");
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
  gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtC);
}

# CuTe DSL (Python)
# Construct a tiled MMA object
atom = tcgen05.MmaF16BF16Op(
    io_dtype,
    acc_dtype,
    mma_inst_shape_mnk,  # (128, 128, 64)
    tcgen05.CtaGroup.ONE,
    tcgen05.OperandSource.SMEM,
    tcgen05.OperandMajorMode.K,
    tcgen05.OperandMajorMode.K,
)
tiled_mma = cute.make_tiled_mma(atom)

tCrA = tiled_mma.make_fragment_A(sA)    # (MMA, MMA_M, MMA_K, PIPE)
tCrB = tiled_mma.make_fragment_B(sB)    # (MMA, MMA_N, MMA_K, PIPE)
tCtC = tiled_mma.make_fragment_C(tCgC)  # (MMA, MMA_M, MMA_N)

num_k_blocks = cute.size(tCrA, mode=[2])
assert num_k_blocks == cute.size(tCrB, mode=[2]), "A and B contraction modes do not match!"
for k_block_idx in range(num_k_blocks):
    cute.gemm(
        tiled_mma, tCtC, tCrA[None, None, k_block_idx], tCrB[None, None, k_block_idx], tCtC)

TiledCopy

A canonical cute::copy is a single loop issuing some data movement instruction to copy one tensor to another, using the layouts of the tensors to describe any transposes or permutations that may happen along the way. cute::TiledCopy is a type used to represent and verify the applicability of optimized data transfers between any two tensors: for example, between memory spaces such as global and shared memory, or within the same memory space, with or without layout transformations such as transposes, using any hardware-accelerated copy atom.
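The "single loop plus two layouts" idea behind a canonical copy can be illustrated without any GPU code. In this hypothetical sketch, `make_layout` returns a function from a 1-D logical index to a memory offset, and `canonical_copy` walks the logical indices once. Giving the destination a column-major layout turns the copy into a transpose. Real CuTe copies also select a copy atom, such as a hardware-accelerated instruction, which this sketch omits.

```python
def make_layout(shape, stride):
    """Return a function: 1-D logical index -> memory offset (leftmost mode fastest)."""
    def layout(idx):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (idx % s) * d
            idx //= s
        return offset
    return layout


def canonical_copy(src, src_layout, dst, dst_layout, size):
    """One loop over logical elements; the layouts decide where each one lives."""
    for i in range(size):
        dst[dst_layout(i)] = src[src_layout(i)]


# A 2x3 matrix stored row-major: element (r, c) is at offset 3*r + c.
src = [0, 1, 2,
       3, 4, 5]
dst = [None] * 6
canonical_copy(src, make_layout((2, 3), (3, 1)),   # row-major source
               dst, make_layout((2, 3), (1, 2)),   # column-major destination
               size=6)
print(dst)  # [0, 3, 1, 4, 2, 5]
```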

// CuTe C++
using TMEM_LOAD = typename std::conditional<sizeof(TC) == 4, SM100_TMEM_LOAD_16dp256b1x, SM100_TMEM_LOAD_16dp256b1x_16b>::type;
// tCtC is the accumulator tensor in TMEM
auto tiled_ldtm = make_tmem_copy(TMEM_LOAD{}, tCtC);
auto thr_ldtm   = tiled_ldtm.get_slice(threadIdx.x);
Tensor tDtC = thr_ldtm.partition_S(tCtC);    // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDgC = thr_ldtm.partition_D(tCgC);    // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDrC = make_tensor<TC>(shape(tDgC));  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
// TMEM -> registers
copy(tiled_ldtm, tDtC, tDrC);

# CuTe DSL (Python)
# Construct a tensor-memory-to-register (T2R) tiled copy object
# tCtACC is the accumulator tensor, with layout (MMA, MMA_M, MMA_N)
# tCgC is the partitioned result (MMA, MMA_M, MMA_N, RestM, RestN, RestL) of global tensor C (M, N)
copy_atom = cute.make_copy_atom(
    tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE),
    cutlass.Float32)
tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom, tCtACC)
tidx, _, _ = cute.arch.thread_idx()
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
# Tensor memory layout (T2R_M, T2R_N, EPI_M, EPI_N)
tT2R_tAcc = thr_copy_t2r.partition_S(tCtACC)
# (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
tT2R_gC = thr_copy_t2r.partition_D(tCgC)
# Construct the register fragment from the shape of the partitioned global tensor
tT2R_rAcc = cute.make_fragment(
    tT2R_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32)
cute.copy(tiled_copy_t2r, tT2R_tAcc, tT2R_rAcc)

CuTe DSL performance across multiple GPU generations

One of the key factors that has driven the adoption of CUTLASS C++ in training and inference frameworks is its ability to deliver blazing-fast performance. CuTe DSL delivers nearly the same level of performance, and more optimizations are in the pipeline.

Additionally, CUTLASS 3 and the underlying CuTe have been deployed in research and production on the last few generations of GPU hardware. Deployed GPUs have a long shelf life in production environments, sometimes in heterogeneous settings. To support these deployments, CuTe DSL supports NVIDIA GPU generations from Ampere to Blackwell at launch.

NVIDIA Blackwell performance 

We measured the performance of both CUTLASS C++ and CuTe DSL on three key operations: dense GEMM, grouped GEMM, and FMHA. Overall, CuTe DSL performance is similar to that of CUTLASS C++.

Dense GEMM 

We measured the performance of dense GEMM in two precision settings, float16 and float8 e4m3. Both types use float32 as the accumulation precision.

Figure 1 shows comparative benchmarking on NVIDIA DGX B200 of CuTe DSL dense GEMM and CUTLASS 3 dense GEMM from the NVIDIA/cutlass GitHub repo. The x-axis shows the tested problem sizes, and the y-axis shows Tensor Core math throughput efficiency as measured with NVIDIA Nsight Compute.
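For a back-of-envelope counterpart to the efficiency metric, the sketch below computes achieved FLOP/s as a fraction of an assumed peak, using the standard 2·M·N·K operation count for a dense GEMM. This is an approximation from wall-clock timing, not the Nsight Compute measurement behind Figure 1. The peak value is a placeholder that must come from the specifications for the specific GPU and data type.

```python
def gemm_flops(m, n, k):
    """A dense GEMM does one multiply and one add for every (m, n, k) triple."""
    return 2 * m * n * k


def math_efficiency(m, n, k, seconds, peak_flops):
    """Achieved FLOP/s divided by an assumed peak FLOP/s (a value from 0 to 1)."""
    return gemm_flops(m, n, k) / seconds / peak_flops


# Hypothetical inputs only: a 1000^3 GEMM timed at 0.5 s against an 8e9 FLOP/s peak.
print(math_efficiency(1000, 1000, 1000, seconds=0.5, peak_flops=8e9))  # 0.5
```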

For small GEMM-K problem sizes (K=512), the DSL kernel is currently slower than C++. This is due to inefficient synchronization before the kernel enters its math computation, which the team is actively working to optimize.
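A simple cost model shows why a fixed cost before the mainloop matters most at small K: the same synchronization time is amortized over fewer K-block iterations. The function and all of its inputs (`k_block=64` and the per-iteration and synchronization costs) are hypothetical. They are not measurements of the CuTe DSL or CUTLASS kernels.

```python
def mainloop_fraction(k, k_block, t_block, t_sync):
    """Fraction of kernel time spent in the math mainloop when a fixed
    synchronization cost t_sync precedes k // k_block iterations of t_block
    each. All times are in arbitrary, hypothetical units."""
    iters = k // k_block
    math_time = iters * t_block
    return math_time / (t_sync + math_time)


for k in (512, 2048, 8192):
    print(k, round(mainloop_fraction(k, k_block=64, t_block=1.0, t_sync=4.0), 3))
```

As K grows, the fixed cost shrinks relative to the math, which is consistent with the gap appearing only for the small-K cases.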

Figure 1. Comparative benchmarking on NVIDIA DGX B200 with CuTe DSL dense GEMM and CUTLASS 3 dense GEMM. The bar chart (‘B200 Dense GEMM Math Throughput Efficiency %’) shows DSL performance on par with C++ except for the small-K cases.

Grouped GEMM 

Comparative benchmarking uses CuTe DSL grouped GEMM and CUTLASS 3 grouped GEMM from the NVIDIA/cutlass GitHub repo.

Figure 2. Comparative benchmarking on NVIDIA DGX B200 with Float16 I/O CuTe DSL grouped GEMM and CUTLASS 3 grouped GEMM. The bar chart (‘B200 Float16 I/O Group GEMM Math Throughput Efficiency %’) shows DSL performance on par with C++.

Fused Multi-Head Attention (FMHA)

Comparative benchmarking uses CuTe DSL FMHA and CUTLASS 3 FMHA from the NVIDIA/cutlass GitHub repo. 

Figure 3. Comparative benchmarking on NVIDIA DGX B200 with Float16 I/O CuTe DSL Flash Attention and CUTLASS 3 Flash Attention. The bar chart (‘B200 Float16 I/O Flash Attention Math Throughput Efficiency %’) shows DSL performance on par with C++.

Ampere performance: Dense GEMM

Comparative benchmarking uses CuTe DSL dense GEMM (Ampere) and CUTLASS 3 dense GEMM (Ampere) from the NVIDIA/cutlass GitHub repo.

Figure 4. Comparative benchmarking on NVIDIA A100 with Float16 I/O CuTe DSL dense GEMM and CUTLASS 3 dense GEMM. The bar chart (‘A100 FP16 I/O Dense GEMM Math Throughput Efficiency %’) shows DSL performance slightly below C++; the performance gaps are under investigation.

Reduction in compilation time 

CuTe DSL lets kernel developers JIT-compile kernels written with CuTe abstractions, avoiding the long compilation times of C++ templates.

As shown in Figure 5, the reduction in compilation time is remarkable: up to two orders of magnitude. This enables kernel developers to try more tile sizes and layout shapes and quickly identify the configuration that delivers the best performance. It could also reduce the total autotuning time of PyTorch Inductor.

GEMM on Blackwell achieves ~100x compilation speedup over C++, while flash attention on Blackwell delivers compilation speedups of 30-50x.
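To see how compile-time savings translate into autotuning time, consider a serial sweep that compiles and then benchmarks each candidate configuration. Every number below is a made-up placeholder, not a measurement. Only the 100x compile speedup echoes the GEMM figure quoted above, and `sweep_time` is a hypothetical helper.

```python
def sweep_time(n_configs, compile_s, bench_s):
    """Serial autotuning: compile and then benchmark every candidate configuration."""
    return n_configs * (compile_s + bench_s)


# Hypothetical inputs, not measurements: 200 configurations, 30 s to compile
# each with C++ templates, 0.5 s to benchmark each, and a 100x faster compile.
n, cpp_compile_s, bench_s, compile_speedup = 200, 30.0, 0.5, 100
cpp_total = sweep_time(n, cpp_compile_s, bench_s)
dsl_total = sweep_time(n, cpp_compile_s / compile_speedup, bench_s)
print(f"end-to-end speedup: {cpp_total / dsl_total:.1f}x")
```

With these placeholder inputs, the whole sweep is roughly 38x faster rather than 100x because benchmarking time is unchanged. The absolute time saved still grows with every additional configuration compiled.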

Figure 5. CuTe DSL compilation on NVIDIA Blackwell is much faster than C++. The bar chart shows compilation speedups of 33-116x on B200.

Easy DL framework integration

Through its support for the DLPack protocol, CuTe DSL can take tensors from popular deep learning frameworks directly as input and convert them into cute.Tensor objects without copying the underlying memory.
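Conceptually, a zero-copy conversion wraps memory that a framework already owns with shape and stride metadata instead of copying it. The sketch below uses Python's buffer protocol (`memoryview`) as a stand-in for DLPack, which exchanges similar metadata: data pointer, shape, strides, dtype, and device. `TensorView` is a hypothetical class, not `cute.Tensor`. In the CUTLASS Python examples, the actual conversion is done with `from_dlpack` from `cutlass.cute.runtime`; check the CuTe DSL documentation for its exact parameters.

```python
from array import array


class TensorView:
    """Hypothetical stand-in for a tensor that wraps an existing buffer
    with a shape and stride instead of copying the data."""

    def __init__(self, buffer, shape, stride):
        self.data = memoryview(buffer)  # no copy: shares the buffer's memory
        self.shape = shape
        self.stride = stride

    def _offset(self, coord):
        return sum(c * s for c, s in zip(coord, self.stride))

    def __getitem__(self, coord):
        return self.data[self._offset(coord)]

    def __setitem__(self, coord, value):
        self.data[self._offset(coord)] = value


storage = array("f", [0.0] * 6)      # memory owned by "the framework"
t = TensorView(storage, shape=(2, 3), stride=(3, 1))
t[1, 2] = 7.0                        # write through the view
print(storage[5])                    # 7.0: the original buffer changed
```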

The CuTe DSL Python-native interfaces allow deep learning frameworks to embed customized kernels directly without requiring cumbersome glue code or deep expertise in CUDA C++. This accelerates development cycles by enabling researchers and engineers to prototype and deploy custom linear algebra kernels rapidly within their existing model pipelines. 

The DSL's composable layout abstractions simplify expressing complex memory and thread mappings, which are critical for exploiting Tensor Core hardware efficiently across NVIDIA Ampere, Hopper, and Blackwell architectures.
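One way to see how layouts express thread mappings is a thread-value (TV) layout, which maps a (thread, value) pair to an offset in a tile. In this hypothetical pure-Python sketch, changing two strides switches between an interleaved assignment and a blocked assignment. The partitioning code stays the same.

```python
def make_tv_layout(t_stride, v_stride):
    """Map (thread, value) to a tile offset: t * t_stride + v * v_stride."""
    return lambda t, v: t * t_stride + v * v_stride


def partition(tile, tv, num_threads, values_per_thread):
    """Return, for each thread, the tile elements it owns under layout tv."""
    return [[tile[tv(t, v)] for v in range(values_per_thread)]
            for t in range(num_threads)]


tile = list(range(8))
interleaved = make_tv_layout(t_stride=1, v_stride=4)  # (4,2):(1,4)
blocked = make_tv_layout(t_stride=2, v_stride=1)      # (4,2):(2,1)
print(partition(tile, interleaved, 4, 2))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(partition(tile, blocked, 4, 2))      # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

The interleaved mapping lets consecutive threads touch consecutive addresses. The blocked mapping gives each thread a contiguous run that suits vectorized accesses. Which one is preferable depends on the memory space and instruction involved.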

Get started with CuTe DSL

CuTe DSL introduces a new programming interface that improves developer velocity while retaining the performance of CUTLASS C++. Check out the Quick Start Guide to learn more about building performant kernels. You can help expand the suite of examples by contributing your kernels to the CUTLASS GitHub repository.

To get started, download CUTLASS and read the CUTLASS documentation. Join the NVIDIA Developer Forum for deeper discussions.

Acknowledgments

We would like to express our gratitude to all of the CUTLASS OSS contributors. Without their foundational contributions, CUTLASS 4 would not have been possible.

