
Unlock GPU Performance: Global Memory Access in CUDA

Managing memory is one of the most important performance considerations when writing a GPU kernel. This post walks you through the key aspects of global memory and its performance that you should know.

Global Memory

There are several kinds of memory on a CUDA device, each with different scope, lifetime, and caching behavior. Global memory (also called device memory) is the primary memory space on CUDA devices. It resides in device DRAM and functions similarly to RAM in CPU systems. The term “global” refers to its scope: it can be accessed and modified by the host (through CUDA runtime API calls such as cudaMemcpy()) and by all threads within a kernel grid.

Global memory can be declared statically using the __device__ declaration specifier at global scope, or allocated dynamically using CUDA runtime APIs such as cudaMalloc() or cudaMallocManaged(). Data can be transferred between host and device using cudaMemcpy(), and dynamically allocated memory is released with cudaFree(). Dynamic allocations persist until they are freed, while statically declared __device__ variables have the lifetime of the CUDA context in which they are created.

Global memory can also be allocated and freed through Unified Memory. Allocating, freeing, and moving global memory to and from the device involve complexities that will be covered in a future post. In this post, we focus on the performance implications of using global memory in CUDA kernels.

A typical usage pattern looks like this: the host allocates and initializes global memory before the kernel launch, CUDA threads read from and write results back to global memory during kernel execution, and the host retrieves the results after the kernel completes.

Example: Dynamic Allocation, Transfer, Kernel, and Cleanup

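#include <cuda_runtime.h>
#include <vector>

// Placeholder kernel for illustration: each thread doubles one element
__global__ void someKernel(const float* input, float* output, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) output[i] = 2.0f * input[i];
}

int main() {
    const int n = 1 << 20;
    std::vector<float> h_input(n, 1.0f), h_output(n);

    // Host allocates global memory
    float* d_input;
    float* d_output;
    cudaMalloc(&d_input, n * sizeof(float));
    cudaMalloc(&d_output, n * sizeof(float));

    // Transfer data to device
    cudaMemcpy(d_input, h_input.data(), n * sizeof(float), cudaMemcpyHostToDevice);

    // Call a kernel to operate on the device, with enough blocks to cover all n elements
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    someKernel<<<blocks, threads>>>(d_input, d_output, n);

    // Copy the result back to the host (cudaMemcpy waits for the kernel to finish)
    cudaMemcpy(h_output.data(), d_output, n * sizeof(float), cudaMemcpyDeviceToHost);

    // Cleanup
    cudaFree(d_input);
    cudaFree(d_output);
    return 0;
}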

Global Memory Coalescing

Before we go into global memory access performance, we need to refine our understanding of the CUDA execution model. We have discussed how threads are grouped into thread blocks, which are assigned to multiprocessors on the device. During execution, there’s a finer grouping of threads into warps. Multiprocessors on the GPU execute instructions for each warp in SIMT (Single Instruction Multiple Threads) fashion. The warp size (effectively the SIMT width) of all current CUDA-capable GPUs is 32 threads.
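To make the grouping concrete, here is a minimal host-side sketch in plain C++ (it also compiles as CUDA host code). It assumes a one-dimensional block and the 32-thread warp size mentioned above; the helper names warpInBlock, laneInWarp, and warpsPerBlock are hypothetical. Inside a kernel, the same arithmetic applies to threadIdx.x, and device code can read the warp size from the built-in warpSize variable.

```cpp
#include <cassert>

// Warp size of all current CUDA-capable GPUs (device code can also read warpSize).
constexpr int kWarpSize = 32;

// Hypothetical helpers for a 1D block: which warp a thread belongs to ...
int warpInBlock(int threadIdxX) { return threadIdxX / kWarpSize; }
// ... and its position (lane) within that warp.
int laneInWarp(int threadIdxX) { return threadIdxX % kWarpSize; }

// Warps needed for a block; a partially filled last warp still occupies a whole warp.
int warpsPerBlock(int blockSize) { return (blockSize + kWarpSize - 1) / kWarpSize; }
```

For example, a 256-thread block holds 8 warps, and thread 33 is lane 1 of warp 1.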

A crucial aspect for you to consider when accessing global memory in CUDA is how memory locations accessed by different threads within the same warp are related. The pattern of these memory accesses directly affects memory access efficiency and overall application performance.  

Global memory is accessed via 32-byte memory transactions, called sectors. When a warp executes a global memory instruction, the accesses of all its threads are coalesced into as few transactions as possible. The number of transactions required depends on the size of the word accessed by each thread and on how the memory addresses are distributed across the threads.
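The following sketch models the sector count for one warp-wide access. It is a simplification for illustration, not a description of the actual hardware logic: it only counts the distinct 32-byte sectors covered by the threads’ addresses. The functions sectorsTouched and warpAddresses are hypothetical names, and addresses are treated as plain byte offsets.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Sector size used throughout this post.
constexpr std::uint64_t kSectorBytes = 32;

// Simplified model (an assumption, not the hardware algorithm): count the distinct
// 32-byte sectors covered by one warp-wide access, given each active thread's
// starting byte address and the size of the word it accesses.
std::size_t sectorsTouched(const std::vector<std::uint64_t>& addresses,
                           std::uint64_t wordBytes) {
    assert(wordBytes > 0);
    std::set<std::uint64_t> sectors;
    for (std::uint64_t addr : addresses) {
        // A misaligned word can straddle two sectors, so cover its first and last byte.
        const std::uint64_t first = addr / kSectorBytes;
        const std::uint64_t last = (addr + wordBytes - 1) / kSectorBytes;
        for (std::uint64_t s = first; s <= last; ++s) {
            sectors.insert(s);
        }
    }
    return sectors.size();
}

// Hypothetical helper: byte addresses for 32 threads reading 4-byte floats
// at base + tid * strideElems * 4.
std::vector<std::uint64_t> warpAddresses(std::uint64_t base, std::uint64_t strideElems) {
    std::vector<std::uint64_t> addrs;
    for (std::uint64_t tid = 0; tid < 32; ++tid) {
        addrs.push_back(base + tid * strideElems * sizeof(float));
    }
    return addrs;
}
```

With this model, 32 consecutive floats starting at a sector boundary touch 4 sectors, a stride of 32 floats touches 32 sectors, and the same consecutive access shifted by 16 bytes touches 5, which is one reason alignment matters.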

The following code demonstrates a scenario in which consecutive threads within a warp access consecutive 4-byte data elements, creating an optimal memory access pattern. The 128 bytes requested by a warp (32 threads × 4 bytes) can be satisfied by four 32-byte sectors, the most efficient use of memory bandwidth, provided the array is suitably aligned (allocations returned by cudaMalloc() are aligned to at least 256 bytes). Figure 1 shows how each thread accesses a 4-byte element of data in contiguous memory.

__global__ void coalesced_access(float* input, float* output, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) {
        // Each thread accesses consecutive 4-byte words
        output[tid] = input[tid] * 2.0f;
    }
}
Figure 1. Coalesced Memory access pattern showing the threads (arrows) of a warp accessing a contiguous 128-byte memory chunk in four 32-byte sectors.

Conversely, if threads access memory with large strides, each memory transaction fetches much more data than is needed. For each 4-byte element a thread requests, an entire 32-byte sector is fetched from global memory, and most of the transferred data goes unused. Figure 2 shows an example of this pattern.

__global__ void uncoalesced_access(float* input, float* output, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) {
        // Access with a stride of 32 elements (128 bytes), wrapped around to stay within bounds.
        // The 64-bit product avoids int overflow for large tid.
        int scattered_index = static_cast<int>((static_cast<long long>(tid) * 32) % n);
        output[tid] = input[scattered_index] * 2.0f;
    }
}
Figure 2. Uncoalesced Memory access pattern showing each thread (arrow) accessing data in a separate 32-byte memory sector.

Let’s dive into analyzing memory access patterns of these two contrasting CUDA kernels using NVIDIA Nsight Compute (NCU). NCU provides powerful metrics to quantify memory access patterns. 

To begin profiling a kernel, we typically run:

ncu --set full --print-details=all ./a.out

This command collects all available profiling sections, including memory, instruction, launch, occupancy, cache, and more. When focusing specifically on memory access efficiency, however, we can narrow the output down to metrics that quantify memory workload patterns. To isolate details related only to the memory workload, the following command is more appropriate:

ncu --section MemoryWorkloadAnalysis_Tables --print-details=all ./a.out

The output from this command is shown below, simplified for clarity.

 coalesced_access(float *, float *, int) (262144, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9

  uncoalesced_access(float *, float *, int) (262144, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
    Section: Memory Workload Analysis Tables
    OPT   Est. Speedup: 83%
          The memory access pattern for global loads from DRAM might not be optimal. On average, only 4.0 of the 32
          bytes transmitted per sector are utilized by each thread. This applies to the 100.0% of sectors missed in
          L2. This could possibly be caused by a stride between threads. Check the Source Counters section for
          uncoalesced global loads.

From the output, we can see that NCU has identified an opportunity for improvement in the global loads of the “uncoalesced_access” kernel: on average, we’re using only 4 bytes of each 32-byte sector that is fetched. NCU even suggests that “this could possibly be caused by a stride between threads.”

We specifically set up the problem to illustrate both good and bad memory performance, so this isn’t surprising. To dig a bit further, we can look at what other kinds of memory analysis tables NCU can provide. 

Since the initial output of NCU identified issues with loads from DRAM, we’ll next try this command to dig deeper into the DRAM statistics.

ncu --metrics group:memory__dram_table ./a.out
 coalesced_access(float *, float *, int) (262144, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
    Section: Command line profiler metrics
    --------------------------------------------------- ----------- ------------
    Metric Name                                         Metric Unit Metric Value
    --------------------------------------------------- ----------- ------------
    dram__bytes_read.sum                                      Mbyte       268.44
    dram__bytes_read.sum.pct_of_peak_sustained_elapsed            %        46.76
    dram__bytes_read.sum.per_second                         Gbyte/s       159.76
    dram__bytes_write.sum                                     Mbyte       248.50
    dram__bytes_write.sum.pct_of_peak_sustained_elapsed           %        43.28
    dram__bytes_write.sum.per_second                        Gbyte/s       147.89
    dram__sectors_read.sum                                   sector    8,388,900
    dram__sectors_write.sum                                  sector    7,765,572
    --------------------------------------------------- ----------- ------------

  uncoalesced_access(float *, float *, int) (262144, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
    Section: Command line profiler metrics
    --------------------------------------------------- ----------- ------------
    Metric Name                                         Metric Unit Metric Value
    --------------------------------------------------- ----------- ------------
    dram__bytes_read.sum                                      Gbyte         2.15
    dram__bytes_read.sum.pct_of_peak_sustained_elapsed            %        84.92
    dram__bytes_read.sum.per_second                         Gbyte/s       290.16
    dram__bytes_write.sum                                     Mbyte       263.70
    dram__bytes_write.sum.pct_of_peak_sustained_elapsed           %        10.43
    dram__bytes_write.sum.per_second                        Gbyte/s        35.63
    dram__sectors_read.sum                                   sector   67,110,368
    dram__sectors_write.sum                                  sector    8,240,680
    --------------------------------------------------- ----------- ------------

This result shows a large difference in dram__sectors_read.sum between the two kernels. Each kernel reads one array and writes another array of the same size, so the amount of data read should be about the same as the amount written. That holds for the coalesced kernel, but the uncoalesced kernel reads about 8x as many sectors as it writes.
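As a quick arithmetic check, the byte counts in the tables above are the sector counts multiplied by 32 bytes, for example 8,388,900 × 32 = 268,444,800 bytes (268.44 MB). The small sketch below, with hypothetical helper names, reproduces those conversions and the read/write ratios.

```cpp
#include <cassert>
#include <cstdint>

// The DRAM tables above are consistent with 32-byte sectors (bytes = sectors * 32).
constexpr std::uint64_t kSectorBytes = 32;

std::uint64_t sectorsToBytes(std::uint64_t sectors) { return sectors * kSectorBytes; }

// Ratio of DRAM sectors read to sectors written. For a kernel that reads one
// array and writes another of the same size, a value near 1 is expected.
double readWriteRatio(std::uint64_t sectorsRead, std::uint64_t sectorsWritten) {
    return static_cast<double>(sectorsRead) / static_cast<double>(sectorsWritten);
}
```

The ratio comes out to about 1.08 for the coalesced kernel and about 8.14 for the uncoalesced one.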

Now let’s analyze the L1 behavior using this command:

ncu --metrics group:memory__first_level_cache_table ./a.out

This command outputs a lot of information, which we’ve omitted here. If you run it, the key is to notice the metrics that differ between the two kernels. Two of them are worth investigating further: l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum and l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum. NCU provides a table that helps you decode what information these metrics collect. The first metric is essentially the number of global load requests made, and the second is the number of sectors fetched to serve them.

When profiling GPU kernels for memory efficiency, sectors (32-byte chunks of data transferred from memory) and requests (memory operations issued by warps) provide valuable insight into memory coalescing behavior. The ratio of sectors to requests gives a clear picture of how efficiently the code uses the memory system.
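The sketch below turns the two counters into the quantities discussed here. It assumes every request comes from a full warp of 32 active threads, each loading one word of wordBytes bytes, which holds for the kernels in this post; the function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

constexpr double kSectorBytes = 32.0;

// Average number of sectors fetched per warp-level request.
double sectorsPerRequest(std::uint64_t sectors, std::uint64_t requests) {
    return static_cast<double>(sectors) / static_cast<double>(requests);
}

// Fraction of fetched bytes that the threads actually asked for, assuming every
// request comes from a full warp of 32 threads, each loading wordBytes bytes.
double sectorUtilization(std::uint64_t sectors, std::uint64_t requests, double wordBytes) {
    const double bytesRequested = 32.0 * wordBytes * static_cast<double>(requests);
    const double bytesFetched = kSectorBytes * static_cast<double>(sectors);
    return bytesRequested / bytesFetched;
}
```

Plugging in the counts reported below, the coalesced kernel gives 8,388,608 / 2,097,152 = 4 sectors per request and a utilization of 1.0, while the uncoalesced kernel gives 32 sectors per request and a utilization of 0.125.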

We can collect only these two metrics if we use the following command:

ncu --metrics l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum,l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum ./a.out

The output we obtain is:

coalesced_access(float *, float *, int) (262144, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    ----------------------------------------------- ----------- ------------
    Metric Name                                     Metric Unit Metric Value
    ----------------------------------------------- ----------- ------------
    l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum                  2097152
    l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum       sector      8388608
    ----------------------------------------------- ----------- ------------

  uncoalesced_access(float *, float *, int) (262144, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    ----------------------------------------------- ----------- ------------
    Metric Name                                     Metric Unit Metric Value
    ----------------------------------------------- ----------- ------------
    l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum                  2097152
    l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum       sector     67108864
    ----------------------------------------------- ----------- ------------

In the coalesced kernel, the ratio of requests to sectors is 1:4, which is what we’d expect. Recall from Figure 1 that a perfectly coalesced 128-byte warp access requires four 32-byte sectors. Every byte fetched from memory is used by the kernel, so none of the fetched data is wasted.

In the uncoalesced kernel, the ratio of requests to sectors is 1:32, which is also what we’d expect from Figure 2, where each thread requests 4 bytes from a different 32-byte sector. Each warp request therefore touches 32 sectors: the memory system fetches 1,024 bytes, of which the warp uses only 128 (12.5%).

This 8x efficiency difference has profound implications for GPU performance, because memory bandwidth often determines the ultimate performance limit of a GPU kernel. More information on profiling, including memory sectors, can be found in the Profiling Guide section of the Nsight Compute documentation.

Strided Access

Now let’s look at the effect of strides on memory bandwidth. In the context of CUDA memory access patterns, stride refers to the distance (measured in array elements or bytes) between the memory locations accessed by consecutive threads of a warp.

Figure 3 shows bandwidth measurements for kernels like those above, run with different access strides. It isn’t intended to show the maximum achievable bandwidth, only how the bandwidth of a simple kernel changes when access to global memory is strided.

Figure 3. Bandwidth versus stride on GH200 for strides from 0 to 31, showing decreasing values.

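The graph shows that effective bandwidth is poor for large strides, as expected. When threads of a warp access addresses that are far apart in memory, the hardware can’t combine their accesses into a few sectors. With 4-byte elements, once the stride reaches 8 elements (32 bytes), every thread’s access falls in a separate 32-byte sector, so each warp fetches 1,024 bytes to use 128.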
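A simple closed-form model shows where the fall-off comes from. Assuming a full warp, 4-byte elements, and a warp base address aligned to a 32-byte sector, lane i of a warp with a stride of s elements falls in sector floor(i × s / 8). That gives 4 × s sectors per request for strides 1 through 8, and 32 sectors for any stride of 8 or more. The sketch ignores caching and DRAM effects, so it explains the trend in Figure 3 rather than its exact shape; the function names are hypothetical.

```cpp
#include <cassert>

// Simplified model (assumption): a full warp of 32 threads loads 4-byte elements at
// base + lane * strideElems * 4, with base aligned to a 32-byte sector. Lane i then
// falls in sector floor(i * strideElems / 8).
int sectorsPerWarp(int strideElems) {
    assert(strideElems >= 0);
    if (strideElems == 0) return 1;               // every lane reads the same word
    return strideElems >= 8 ? 32 : 4 * strideElems;  // one sector per lane from stride 8 on
}

// Fraction of fetched bytes the warp uses: 128 useful bytes (32 lanes x 4 bytes)
// out of 32 bytes per fetched sector. Only meaningful for strides of 1 or more.
double strideEfficiency(int strideElems) {
    assert(strideElems >= 1);
    return 128.0 / (32.0 * sectorsPerWarp(strideElems));
}
```

For example, a stride of 2 already halves the useful fraction of fetched data, and any stride of 8 or more leaves only 1/8 of it in use.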

Multidimensional Arrays

Now let’s talk about memory access in the case of multidimensional arrays, or matrices. To get the best performance and achieve coalesced memory access, it’s important for consecutive threads to access consecutive elements in the array, just like in the 1D case.

When using 2D or 3D thread blocks in a CUDA kernel, the threads are laid out linearly with the X index (threadIdx.x) varying fastest, then Y (threadIdx.y), and then Z (threadIdx.z). For example, in a 2D thread block of size (4,2), the threads are ordered as: (0,0)(1,0)(2,0)(3,0)(0,1)(1,1)(2,1)(3,1).
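The flattening rule can be written out directly. The following sketch computes a thread’s linear index within its block and the warp that index falls into, assuming the 32-thread warp size; the function names are hypothetical.

```cpp
#include <cassert>

// Linear index of a thread within its block: x varies fastest, then y, then z.
int linearThreadIndex(int x, int y, int z, int blockDimX, int blockDimY) {
    return x + y * blockDimX + z * blockDimX * blockDimY;
}

// Warps are built from consecutive linear indices, 32 threads at a time.
int warpOfThread(int x, int y, int z, int blockDimX, int blockDimY) {
    return linearThreadIndex(x, y, z, blockDimX, blockDimY) / 32;
}
```

In the (4,2) block, thread (0,1) has linear index 4, so it directly follows (3,0). In a 32 × 32 block, each value of threadIdx.y forms exactly one warp.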

It’s typical to use 2D thread blocks in CUDA when accessing 2D data such as a matrix. C++ stores 2D arrays in row-major order, so when a matrix is stored as a 1D array in memory, the elements of each row are contiguous. If consecutive threads access consecutive elements across a row, those accesses are efficient (coalesced), while accessing down a column is inefficient (strided, uncoalesced).

Since consecutive threadIdx.x values within a warp should access consecutive memory elements for coalescing, the threads with the same threadIdx.y value should access a row of the matrix. This ensures that when threads in a warp access matrix elements, they follow the natural row-major memory layout, enabling efficient coalesced memory transactions and maximizing memory bandwidth utilization.

In the coalesced kernel (coalesced_matrix_access) that follows, the memory access pattern is efficient because of how thread indices are mapped to matrix coordinates, given the row-major storage order. The x-dimension of each block (threadIdx.x) is assigned to the column index, so as consecutive threads within a warp increase their threadIdx.x, they access consecutive columns of the matrix while staying in the same row (Figure 4). Because row-major order stores the elements of a row in consecutive memory locations, the threads of the warp access contiguous memory.

__global__ void coalesced_matrix_access(float* matrix, int width, int height)  
{  
    int row = blockIdx.y * blockDim.y + threadIdx.y;  
    int col = blockIdx.x * blockDim.x + threadIdx.x;  
    if (row < height && col < width) {  
        int idx = row * width + col;          // row-major ⇒ coalesced  
        matrix[idx] = matrix[idx] * 2.0f + 1.0f;  
    }  
}
Figure 4. Coalesced 2D access showing how 2D threadblocks map to the 2D matrix, and also how it maps to the linear memory where the matrix resides. Consecutive threads access consecutive row elements which are contiguous in memory.

For the uncoalesced kernel (uncoalesced_matrix_access) shown next, the memory access pattern results in inefficient uncoalesced accesses. 

__global__ void uncoalesced_matrix_access(float* matrix, int width, int height)  
{  
    int row = blockIdx.y * blockDim.y + threadIdx.y;  
    int col = blockIdx.x * blockDim.x + threadIdx.x;  
    if (row < height && col < width) {  
        int idx = col * height + row;         // column-major ⇒ uncoalesced  
        matrix[idx] = matrix[idx] * 2.0f + 1.0f;  
    }  
}

Here, to illustrate the point, the kernel artificially treats the row-major matrix as if it were column-major by using the index calculation col * height + row. As consecutive threads within a warp increase their threadIdx.x (incrementing the column index), they access elements that would be consecutive in a column-major layout but are strided in the row-major layout. Since the data is physically stored in row-major order but accessed with column-major indexing, consecutive threads access memory locations spaced height elements apart. This large stride prevents the GPU from coalescing the accesses into efficient transactions (Figure 5), and the mismatch between storage order and access pattern leads to poor global memory bandwidth utilization.
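To quantify the difference, the sketch below lists the element indices that one warp touches under each kernel’s index formula and measures the distance between neighboring lanes. It assumes blockDim.x is 32, so a warp shares a single row and spans 32 consecutive columns. The example size of 16,384 × 16,384 is an assumption, inferred from the (512, 512) × (32, 32) launch profiled below with one thread per element. The function names are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Element indices read by the 32 lanes of one warp, assuming blockDim.x == 32 so the
// warp shares one row and covers columns col0 .. col0 + 31.
std::array<std::int64_t, 32> warpIndices(bool rowMajorIndexing, std::int64_t row,
                                         std::int64_t col0, std::int64_t width,
                                         std::int64_t height) {
    std::array<std::int64_t, 32> idx{};
    for (int lane = 0; lane < 32; ++lane) {
        const std::int64_t col = col0 + lane;
        // Same index formulas as coalesced_matrix_access / uncoalesced_matrix_access.
        idx[lane] = rowMajorIndexing ? row * width + col : col * height + row;
    }
    return idx;
}

// Distance in bytes between the floats read by two neighboring lanes.
std::int64_t laneStrideBytes(const std::array<std::int64_t, 32>& idx) {
    return (idx[1] - idx[0]) * static_cast<std::int64_t>(sizeof(float));
}
```

A 4-byte lane-to-lane distance packs the warp’s 128 bytes into four 32-byte sectors, whereas a 65,536-byte distance puts every lane in its own sector.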

Figure 5. Uncoalesced 2D access showing how 2D threadblocks map to the 2D matrix, and also how it maps to the linear memory where the matrix resides.  Consecutive threads access consecutive column elements, which are not contiguous in memory.

We can observe this behavior by examining the profiling results below:

coalesced_matrix_access(float *, int, int) (512, 512, 1)x(32, 32, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    ----------------------------------------------- ----------- ------------
    Metric Name                                     Metric Unit Metric Value
    ----------------------------------------------- ----------- ------------
    l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum                  8388608
    l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum       sector     33554432
    ----------------------------------------------- ----------- ------------

  uncoalesced_matrix_access(float *, int, int) (512, 512, 1)x(32, 32, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    ----------------------------------------------- ----------- ------------
    Metric Name                                     Metric Unit Metric Value
    ----------------------------------------------- ----------- ------------
    l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum                  8388608
    l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum       sector    268435456
    ----------------------------------------------- ----------- ------------

Both kernels generate the same number of memory requests (8,388,608), but the coalesced version requires only 33,554,432 sectors compared to 268,435,456 for the uncoalesced version. That is 4 sectors per request for the coalesced kernel versus 32 for the uncoalesced kernel. The ratio of 4 indicates efficient coalescing: because the accesses are contiguous, the 32 threads of a warp are served by the minimum number of sectors. The ratio of 32 shows that each thread’s access lands in its own sector, so the strided pattern forces the memory subsystem to fetch eight times as many sectors to satisfy the same requests.
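The request counts themselves can be predicted from the launch configuration. Assuming every thread passes the bounds check and issues exactly one global load, and consistent with the profiles above, each warp’s load instruction counts as one request, so the total is the number of blocks times the warps per block. The sketch below, with hypothetical function names, reproduces the counts in both sets of profiling results.

```cpp
#include <cassert>
#include <cstdint>

// Expected warp-level load requests, assuming every thread passes the bounds check
// and executes exactly one global load (one request per warp). A partially filled
// warp with at least one active thread still counts as one request.
std::uint64_t expectedLoadRequests(std::uint64_t numBlocks, std::uint64_t threadsPerBlock) {
    const std::uint64_t warpsPerBlock = (threadsPerBlock + 31) / 32;
    return numBlocks * warpsPerBlock;
}

// Expected sectors for an access pattern with a known sectors-per-request ratio.
std::uint64_t expectedSectors(std::uint64_t requests, std::uint64_t sectorsPerRequest) {
    return requests * sectorsPerRequest;
}
```

For the matrix kernels, 512 × 512 blocks × 32 warps gives 8,388,608 requests; for the 1D kernels, 262,144 blocks × 8 warps gives 2,097,152.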

Summary

Efficient use of GPU memory is one of the most important criteria you need to focus on to obtain the best performance possible. Optimal global memory performance relies on using coalesced memory accesses. Make sure to minimize strided access to global memory, and always profile your GPU kernels with Nsight Compute to ensure that your memory accesses are coalesced. This approach will help you get the most performance possible out of your GPU code.

Acknowledgments 

This post is an update to a post that was originally published in 2013 by Mark Harris of NVIDIA.
