
Running Python UDFs in Native NVIDIA CUDA Kernels with RAPIDS cuDF

In this post, I introduce the design and implementation of a framework within RAPIDS cuDF that enables compiling Python user-defined functions (UDFs) and inlining them into native CUDA kernels. The framework uses the Numba Python compiler and the Jitify CUDA just-in-time (JIT) compilation library to give cuDF users the flexibility of Python with the performance of CUDA as a compiled language.

An essential part of the framework is a parser that translates a CUDA PTX function, compiled from the Python UDF, into an equivalent CUDA C++ device function that can be inlined into native CUDA C++ kernels. This approach makes it possible for Python users without CUDA programming knowledge to extend optimized DataFrame operations with their own Python UDFs, and enables more flexibility and generality for high-performance computations on DataFrames in RAPIDS.

I start with examples of how to use the feature, followed by its design goals. Finally, I explain how things work in the background to make the feature possible.

Using the feature

The feature is built into the framework of RAPIDS cuDF and is easy to use. After a DataFrame is created, call one of the interfaces that support this feature with the user-defined Python function. Currently, the supported interfaces include:

  • applymap, which applies a UDF to each of the elements.
  • rolling, which applies a range-based UDF to each of the windows.

The applymap example:

>>> import cudf
>>> from cudf import Series
>>> import numpy as np
>>> a = Series([9, 16, 25, 36, 49], dtype=np.float64)
>>> a.applymap(lambda x: x ** 2)
0      81.0
1     256.0
2     625.0
3    1296.0
4    2401.0
dtype: float64
>>> a.applymap(lambda x: 1 if x in [9, 44] else 2)
0    1
1    2
2    2
3    2
4    2
dtype: int64

The rolling example:

>>> def foo(A):
...     sum = 0
...     for a in A:
...         sum = sum + a
...     return sum
...
>>> a.rolling(3, 1, False).apply(foo)
0      9.0
1     25.0
2     50.0
3     77.0
4    110.0
dtype: float64

Flexibility and performance

The Python/CUDA framework uses JIT compilation to achieve both flexibility and performance in DataFrame operations.

AOT compilation

Traditionally, with ahead-of-time (AOT) compilation, CUDA kernels are compiled into SASS machine-level code at compile time and launched at runtime.

In cases where operator functions must be called by kernels, the use of function pointers or stack frames, which usually jeopardizes performance, is avoided by inlining the operator function, as shown in the following code. This is an illustration, as the actual inlining happens at the NVVM IR level.

Code before inlining:  

__device__ inline int op(int a, int b){
  return a + b;
}

__global__ void kernel(int *C, int *A, int *B){
  int idx = threadIdx.x;
  C[idx] = op(A[idx], B[idx]);
}

Code after inlining:

__global__ void kernel_after_inlining(int *C, int *A, int *B){
  int idx = threadIdx.x;
  C[idx] = A[idx] + B[idx];
}

However, this performance comes at the price of flexibility. The operator function is often not known at compile time: in most cases, the program does not reach end users until runtime, and it is the users who decide what operator function is needed. With AOT compilation, you cannot write your own operator function, and still get maximum performance, without recompiling the whole program.

JIT compilation

This is where JIT, or runtime, compilation helps. Using CUDA runtime compilation (NVRTC) and the Jitify library, the code string of an operator function written at runtime can be inlined into the code string of the kernel before the combination is compiled at runtime and launched with the same performance as a corresponding traditional, native CUDA kernel. Flexibility and performance are achieved at the same time, with the only overhead being the time needed to perform the runtime compilation.
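As a rough sketch of the mechanism (an illustration assuming Jitify's JitCache/Program interface, not cuDF's internal code; the program, kernel, and operator names are made up), the operator source written at runtime is spliced into the kernel source as a string and the combination is handed to Jitify, which drives NVRTC underneath:

#include <string>
#include "jitify.hpp"  // https://github.com/NVIDIA/jitify

void launch_jit_kernel(int* d_C, const int* d_A, const int* d_B, int n){
  // Operator source written at runtime, for example by an end user.
  std::string op_source =
      "__device__ inline int op(int a, int b){ return a + b; }\n";

  // Predefined kernel source; the operator string is spliced in before
  // compilation, so NVRTC can inline it like any other device function.
  std::string program_source =
      "binary_op_program\n" + op_source +
      "__global__ void binary_op_kernel(int *C, const int *A, const int *B, int n){\n"
      "  int idx = blockIdx.x * blockDim.x + threadIdx.x;\n"
      "  if (idx < n) C[idx] = op(A[idx], B[idx]);\n"
      "}\n";

  static jitify::JitCache kernel_cache;
  jitify::Program program = kernel_cache.program(program_source);

  dim3 grid((n + 255) / 256), block(256);
  program.kernel("binary_op_kernel")
         .instantiate()
         .configure(grid, block)
         .launch(d_C, d_A, d_B, n);
}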

Combine Python and CUDA

By combining Python, with its flexibility as an interpreted language, and CUDA, with its performance as a compiled language, you get the best of both. You can write a Python UDF without any knowledge, or even awareness, of CUDA; this feature compiles it, inlines it into carefully optimized, predefined CUDA kernels, and launches it on NVIDIA GPUs with maximum performance, as shown in the usage examples.

Performance benchmark for applymap

For DataFrames with large numbers of rows, I compared the performance of pandas.apply with cudf.applymap. The latter can achieve a significant speedup over the former. The following benchmark was measured on an Intel Xeon Gold 6128 CPU and an NVIDIA Quadro GV100 GPU. These results do not include the overhead of JIT compilation, a one-time cost paid only on the first execution of this feature with a specific UDF.

Number of rows    pandas [v1.0.3] [s]    cudf [branch-0.14] [s]    Speedup
10**7             0.313                  0.065                     4.9
10**8             3.123                  0.068                     46.1
10**9             31.244                 0.072                     435.2

Table 1. Benchmark results.

The following code blocks show the Python code used to produce this benchmark.

Pandas:

import pandas as pd
import numpy as np
 
s = pd.Series(np.random.randint(1, 101, 10**7))
s.apply(lambda x: x in [1027, 1000, 59, 980, 1027, 1000, 59, 980, 1027, 1000, 59, 980, 1027, 1000, 59, 980, 1027, 1000, 59, 980])

CUDF:

import cudf
import numpy as np
 
s = cudf.Series(np.random.randint(1, 101, 10**7))
s.applymap(lambda x: x in [1027, 1000, 59, 980, 1027, 1000, 59, 980, 1027, 1000, 59, 980, 1027, 1000, 59, 980, 1027, 1000, 59, 980])
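
For reference, here is a minimal sketch, not the original benchmark harness, of how the one-time JIT cost can be kept out of the measurement by running the same UDF once before timing:

import time

import cudf
import numpy as np

s = cudf.Series(np.random.randint(1, 101, 10**7))
udf = lambda x: x in [1027, 1000, 59, 980]

# The first call with a given UDF pays the one-time JIT compilation cost.
s.applymap(udf)

# Later calls with the same UDF reuse the compiled kernel; time only these.
start = time.perf_counter()
s.applymap(udf)
print(f"warm run: {time.perf_counter() - start:.3f} s")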

How this feature works in the background

The Python UDF is first compiled into a CUDA PTX function, then backward compiled into a corresponding CUDA C++ function, which is finally inlined into the CUDA kernels.

Numba: A high-performance Python compiler

In addition to the Numba JIT functionality that allows you to write CUDA kernels in Python, Numba can also compile Python functions into CUDA PTX device functions, based on LLVM/NVVM.
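
For example, recent Numba versions expose this capability through numba.cuda.compile_ptx; the exact entry point may vary across Numba releases, but a minimal sketch looks like the following:

from numba import cuda, float64

def udf(x, y):
    return x * x + y

# Compile the Python function into a CUDA PTX device function for the
# given argument types; resty is the inferred return type.
ptx, resty = cuda.compile_ptx(udf, (float64, float64), device=True)
print(ptx)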

To provide some background, the CUDA PTX assembly stage is an intermediate compilation stage between the initial CUDA C++ source and the final SASS machine code, which is what actually runs on the GPU. All input CUDA C++ code is first compiled into CUDA PTX.

The following code block is an example CUDA C++ device function:

__device__ void foo(
      float* p0,
      float  f1,
      float  f2
)
{
   *p0 = f1 + f2;
}

It is compiled into the CUDA PTX assembly shown in the following code block:

.func _Z3fooPfff(
      .param .b64 _param_p0,
      .param .b32 _param_f1,
      .param .b32 _param_f2
)
{
      .reg .f32 %f<4>;
      .reg .b64 %rd<3>;
 
      ld.param.u64 %rd1, [_param_p0];
      ld.param.f32 %f1, [_param_f1];
      ld.param.f32 %f2, [_param_f2];

      add.f32 %f3, %f1, %f2;
      st.f32 [%rd1], %f3;

      ret;
}

Without further work, the compiled PTX functions are ready to be called from CUDA PTX kernels, but only through function pointers, which incur a performance loss due to the use of stack frames. As discussed earlier in this post, the way around this is to inline the functions into the CUDA C++ kernels, which is currently not possible for a CUDA PTX device function.

A natural way to solve the problem is to backward compile the CUDA PTX device functions into CUDA C++ device functions, which the nvcc compiler can then inline directly into CUDA C++ kernels. The bulk of the work that makes this feature possible is the automated workflow that performs this backward compilation.

Backward compiler from CUDA PTX to CUDA C++

CUDA supports the inline PTX syntax that allows you to write the CUDA PTX assembly in CUDA C++ code.

__device__ void foo(
      float* p0,
      float  f1,
      float  f2
)
{
      // does the same as `*p0 = f1 + f2;` but using inline PTX
      asm("add.f32 %0, %1, %2;" : "=f"(*p0) : "f"(f1), "f"(f2));
}

The CUDA PTX code inlined in the CUDA C++ code is copied literally into the PTX stage of the compilation, at which point it is combined with the other PTX code compiled from the rest of CUDA C++ code.

A general picture

The inline PTX syntax does most of the work in the workflow from CUDA PTX to CUDA C++: quoting the PTX code with the inline PTX syntax mostly gives you the CUDA C++ code for free. There is no free lunch, though; some subtleties must be taken care of. Here's how the special cases are dealt with.

To make things concrete, here’s an example CUDA PTX input.

.visible .func  (.param .b32 func_retval0) _Z7add$241Eff(
           .param .b64 _param_0,        // float* p0
           .param .b32 _param_1,        // float  f1
           .param .b32 _param_2         // float  f2
){
 .reg .f32            %f<5>;
 .reg .b32            %r<2>;
 .reg .b64            %rd<2>;


  ld.param.u64 %rd1, [_param_0];
  ld.param.f32 %f1,  [_param_1];
  ld.param.f32 %f2,  [_param_2];
  mul.f32 %f3, %f1, %f1;                 // f3 = f1 * f1
  fma.rn.f32          %f4, %f3, %f1, %f2; // f4 = f3 * f1 + f2
  st.f32   [%rd1], %f4;                   // *p0 = f4
  mov.u32 %r1, 0;
  st.param.b32        [func_retval0+0], %r1;
  ret;
}

The output CUDA C++:

__device__ __inline__ void GENERIC_BINARY_OP(
  float* _param_0,
  float  _param_1,
  float  _param_2
){
  asm volatile ("  .reg .f32 _f<5>;");
  asm volatile ("  .reg .b32 _r<2>;");
  asm volatile ("  .reg .b64 _rd<2>;");

 
  asm volatile ("  mov.u64 _rd1,  %0;": : "l"(_param_0));
  asm volatile ("  mov.f32 _f1,  %0;" : : "f"(_param_1));
  asm volatile ("  mov.f32 _f2,  %0;" : : "f"(_param_2));
  asm volatile ("  mul.f32 _f3, _f1, _f1;");
  asm volatile ("  fma.rn.f32 _f4, _f3, _f1, _f2;");
  asm volatile ("  st.f32 [_rd1], _f4;");
  asm volatile ("  mov.u32 _r1, 0;");
           // st.param.b32 [func_retval0+0], _r1;  (dropped: the dummy return value is unused)
           // ret;  (handled separately; see the Return instruction section)
}

Function header

CUDA PTX and CUDA C++ have different grammars for how the function header is written. In practice, it is assumed that the CUDA PTX function has a dummy return value and that the actual output of the function is written to the memory that the first function parameter points to; the first parameter is always interpreted as a pointer. The Numba-compiled CUDA PTX functions satisfy this assumption. Because the return value of the CUDA PTX function is a dummy, the output CUDA C++ functions always have a void return type.

The output type must be provided by the user: from a CUDA PTX function parameter alone, there is no way to infer the type of the memory it points to. The need for this information is expected, as the forward compilation loses information along the way; to go backward, the workflow needs additional user input.

The rest of the function parameters (from the second parameter on) are all considered input values of the function. User input for their types is not needed, as these types can be inferred from the function parameter loading instructions in the function body. As an exception, the user must tell the workflow if any of these parameters are pointers, because in the CUDA PTX language, on a 64-bit machine, pointers are just 64-bit registers.
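
Putting these conventions together, the generated device function can be called from a native CUDA C++ kernel along the following lines. This is a hypothetical sketch; the kernel is illustrative rather than cuDF's actual code, and the generated definition of GENERIC_BINARY_OP from the earlier code block is assumed to appear in the same translation unit:

// The generated function returns void and writes its result through the
// first parameter, so the caller passes the address of a local variable.
__global__ void binary_op_kernel(float* out, const float* lhs,
                                 const float* rhs, int n){
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    float result;
    GENERIC_BINARY_OP(&result, lhs[idx], rhs[idx]);  // inlined by nvcc
    out[idx] = result;
  }
}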

Register declaration

The register declarations in the CUDA PTX function are translated verbatim with the inline PTX syntax. However, percent signs (%) in the register names must be replaced with other characters, because they clash with the %n operand placeholders of the inline PTX syntax.

Function parameter loading instructions

The function parameter loading instructions in a CUDA PTX function load the function parameters into registers for later use. They cannot be translated verbatim, because after inlining the function no longer exists as a separate function with its own parameters. Instead, the PTX mov (move) instruction is used to assign the CUDA C++ function parameters to the same registers. An example is shown later.

In addition, the types of these loading instructions tell the types of the function parameters.

In some cases, this translation does not work. For a function parameter loading instruction, the destination register can be wider than the width indicated by the instruction type, as in the following example:

.reg .b64 %rd5;
ld.param.s32   %rd5, [param_to_be_loaded];
->>> asm volatile ("mov.s32 _rd5,  %0;": : "r"(param_to_be_loaded)); // PTX ERROR!

However, the mov instruction does not allow this register width mismatch. This is worked around by using the cvt (convert) instruction, which allows the mismatch:

.reg .b64 %rd5;
ld.param.s32   %rd5, [param_to_be_loaded];
->>> asm volatile ("cvt.s32.s32 _rd5,  %0;": : "r"(param_to_be_loaded)); // WORKS.

Function parameter storing instructions

The function parameter storing instructions are ignored in this workflow, as the function return value is not used to pass the output value.

Return instruction

In the CUDA PTX language, the ret (return) instruction exits the current function. If ret appears in the input CUDA PTX function, the intention is to exit that device function. However, if ret were translated verbatim into the output CUDA C++ function, which is inlined into a CUDA kernel, it would tell the whole kernel to exit. That is not the original intention of ret.

To preserve the original intention, a bra (branch) instruction is used instead, telling the compiler to jump only to the end of the current function, as shown in the following example:

// Input CUDA PTX:
...

ret;   // the device function is supposed to exit here

...    // other instructions that should not be executed
}

// Output CUDA C++:
...

asm ("bra RETTGT;");   // jump to the end of the function instead of returning

...    // skipped

asm ("RETTGT:");
}

For Python lambda functions such as lambda x: 8 if x > 9 else 1, for which there are multiple return instructions in the CUDA PTX function, this treatment is necessary to make things work.

Conclusion

The Python/CUDA JIT compilation in RAPIDS cuDF allows you to apply your own Python function to DataFrames on NVIDIA GPUs with great flexibility while achieving maximum performance. By combining Python and JIT compilation, and by backward compiling CUDA PTX functions into CUDA C++ functions, this approach applies beyond the scope of DataFrame extract, transform, load (ETL) workloads and has potentially many more use cases.

To learn more, see my GTC 2020 talk, Combined Python/CUDA JIT for Flexible Acceleration in RAPIDS. To try this feature out with cuDF, see the rapidsai/cudf GitHub repo.
