-
Notifications
You must be signed in to change notification settings - Fork 791
Add CUDA argmax kernel for LLM sampler #16386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
larryliu0820
wants to merge
1
commit into
main
Choose a base branch
from
gh/larryliu0820/87/head
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
| # LLM sampler library with optional CUDA support | ||
| # | ||
| # ### Editing this file ### | ||
| # | ||
| # This file should be formatted with | ||
| # ~~~ | ||
| # cmake-format -i CMakeLists.txt | ||
| # ~~~ | ||
| # It should also be cmake-lint clean. | ||
| # | ||
|
|
||
| if(NOT EXECUTORCH_ROOT) | ||
| set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) | ||
| endif() | ||
|
|
||
| include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) | ||
|
|
||
| # If the project is configured to build with CUDA support, build the CUDA | ||
| # sampler library. | ||
| if(EXECUTORCH_BUILD_CUDA) | ||
| find_package(CUDAToolkit QUIET) | ||
| if(CUDAToolkit_FOUND) | ||
| enable_language(CUDA) | ||
| set(CMAKE_CUDA_STANDARD 17) | ||
| set(CMAKE_CUDA_STANDARD_REQUIRED ON) | ||
|
|
||
| # Define CUDA sampler library | ||
| add_library(extension_llm_sampler_cuda STATIC argmax.cu cuda_sampler.cu) | ||
| target_include_directories( | ||
| extension_llm_sampler_cuda | ||
| PUBLIC ${EXECUTORCH_ROOT} | ||
| ${EXECUTORCH_ROOT}/.. | ||
| ${CUDAToolkit_INCLUDE_DIRS} | ||
| ) | ||
| target_compile_definitions(extension_llm_sampler_cuda PUBLIC CUDA_AVAILABLE) | ||
| target_link_libraries(extension_llm_sampler_cuda PUBLIC executorch_core | ||
| CUDA::cudart | ||
| ) | ||
| set_target_properties( | ||
| extension_llm_sampler_cuda | ||
| PROPERTIES POSITION_INDEPENDENT_CODE ON CUDA_SEPARABLE_COMPILATION ON | ||
| ) | ||
|
|
||
| message( | ||
| STATUS "CUDAToolkit found; building extension_llm_sampler_cuda library" | ||
| ) | ||
|
|
||
| install( | ||
| TARGETS extension_llm_sampler_cuda | ||
| EXPORT ExecuTorchTargets | ||
| DESTINATION ${CMAKE_INSTALL_LIBDIR} | ||
| ) | ||
| else() | ||
| message( | ||
| STATUS | ||
| "CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found" | ||
| ) | ||
| endif() | ||
| endif() | ||
|
|
||
| # Install header files | ||
| install( | ||
| DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ | ||
| DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/executorch/extension/llm/sampler | ||
| FILES_MATCHING | ||
| PATTERN "*.h" | ||
| PATTERN "*.cuh" | ||
| PATTERN "test" EXCLUDE | ||
| ) | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #include <executorch/extension/llm/sampler/argmax.cuh> | ||
| #include <executorch/extension/llm/sampler/cuda_sampler.h> | ||
| #include <executorch/runtime/platform/log.h> | ||
|
|
||
| namespace executorch { | ||
| namespace extension { | ||
| namespace llm { | ||
| namespace cuda { | ||
|
|
||
| // Wrapper function that performs argmax on GPU logits tensor | ||
| // Returns the token index with the highest logit value | ||
| // logits_ptr: pointer to GPU memory containing logits | ||
| // vocab_size: vocabulary size | ||
| // scalar_type: data type of the logits tensor | ||
| // cuda_stream: CUDA stream for async execution (nullptr for default stream) | ||
| // out_token_gpu: pre-allocated GPU memory for output token (int*) | ||
| int32_t argmax_cuda( | ||
| const void* logits_ptr, | ||
| int vocab_size, | ||
| ::executorch::aten::ScalarType scalar_type, | ||
| cudaStream_t cuda_stream, | ||
| int* out_token_gpu) { | ||
| // Launch kernel for single row (batch size 1) | ||
| launch_argmax_vocab_rows( | ||
| logits_ptr, | ||
| scalar_type, | ||
| 1, // rows = 1 | ||
| vocab_size, | ||
| out_token_gpu, | ||
| nullptr, // don't need max logit value | ||
| cuda_stream, | ||
| 256 // threads per block | ||
| ); | ||
|
|
||
| // Copy result back to host | ||
| int32_t token; | ||
| cudaError_t err = cudaMemcpyAsync( | ||
| &token, out_token_gpu, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream); | ||
| if (err != cudaSuccess) { | ||
| ET_LOG( | ||
| Error, | ||
| "Failed to copy argmax result from GPU: %s", | ||
| cudaGetErrorString(err)); | ||
| return -1; | ||
| } | ||
|
|
||
| // Synchronize to ensure result is ready | ||
| err = cudaStreamSynchronize(cuda_stream); | ||
| if (err != cudaSuccess) { | ||
| ET_LOG( | ||
| Error, | ||
| "Failed to synchronize CUDA stream: %s", | ||
| cudaGetErrorString(err)); | ||
| return -1; | ||
| } | ||
|
|
||
| return token; | ||
| } | ||
|
|
||
| } // namespace cuda | ||
| } // namespace llm | ||
| } // namespace extension | ||
| } // namespace executorch | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,171 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cuda_bf16.h> | ||
| #include <cuda_fp16.h> | ||
| #include <cuda_runtime.h> | ||
| #include <float.h> | ||
| #include <stdint.h> | ||
|
|
||
| #include <executorch/runtime/core/exec_aten/exec_aten.h> | ||
|
|
||
| namespace executorch { | ||
| namespace extension { | ||
| namespace llm { | ||
| namespace cuda { | ||
|
|
||
| struct ArgMaxPair { | ||
| float v; | ||
| int i; | ||
| }; | ||
|
|
||
| // tie-break: smaller index wins on equal values | ||
| __device__ __forceinline__ ArgMaxPair better(ArgMaxPair a, ArgMaxPair b) { | ||
| if (b.v > a.v) | ||
| return b; | ||
| if (b.v < a.v) | ||
| return a; | ||
| return (b.i < a.i) ? b : a; | ||
| } | ||
|
|
||
| __device__ __forceinline__ ArgMaxPair | ||
| warp_argmax_xor(ArgMaxPair x, unsigned mask = 0xffffffffu) { | ||
| for (int d = 16; d > 0; d >>= 1) { | ||
| ArgMaxPair y; | ||
| y.v = __shfl_xor_sync(mask, x.v, d); | ||
| y.i = __shfl_xor_sync(mask, x.i, d); | ||
| x = better(x, y); | ||
| } | ||
| return x; | ||
| } | ||
|
|
||
| // ---- dtype -> float load helpers ---- | ||
| template <typename T> | ||
| __device__ __forceinline__ float load_as_float(const T* p); | ||
|
|
||
| template <> | ||
| __device__ __forceinline__ float load_as_float<float>(const float* p) { | ||
| return *p; | ||
| } | ||
|
|
||
| template <> | ||
| __device__ __forceinline__ float load_as_float<half>(const half* p) { | ||
| return __half2float(*p); | ||
| } | ||
|
|
||
| template <> | ||
| __device__ __forceinline__ float | ||
| load_as_float<nv_bfloat16>(const nv_bfloat16* p) { | ||
| return __bfloat162float(*p); | ||
| } | ||
|
|
||
| // logits: [rows, vocab] row-major contiguous | ||
| // out_token: [rows] | ||
| // out_maxlogit: [rows] (optional; pass nullptr if not needed) | ||
| template <typename T> | ||
| __global__ void argmax_vocab_rows_kernel( | ||
| const T* __restrict__ logits, | ||
| int rows, | ||
| int vocab, | ||
| int* __restrict__ out_token, | ||
| float* __restrict__ out_maxlogit) { | ||
| int row = blockIdx.x; | ||
| if (row >= rows) | ||
| return; | ||
|
|
||
| int tid = threadIdx.x; | ||
| int lane = tid & 31; | ||
| int warp = tid >> 5; | ||
| int warps_per_block = (blockDim.x + 31) >> 5; | ||
|
|
||
| const T* row_ptr = logits + (size_t)row * (size_t)vocab; | ||
|
|
||
| // local scan | ||
| ArgMaxPair best; | ||
| best.v = -FLT_MAX; | ||
| best.i = -1; | ||
|
|
||
| for (int j = tid; j < vocab; j += blockDim.x) { | ||
| float v = load_as_float<T>(row_ptr + j); | ||
| best = better(best, ArgMaxPair{v, j}); | ||
| } | ||
|
|
||
| // warp reduce | ||
| best = warp_argmax_xor(best); | ||
|
|
||
| // shared collect warp winners (supports up to 1024 threads = 32 warps) | ||
| __shared__ float s_val[32]; | ||
| __shared__ int s_idx[32]; | ||
|
|
||
| if (lane == 0) { | ||
| s_val[warp] = best.v; | ||
| s_idx[warp] = best.i; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // first warp reduces warp winners | ||
| if (warp == 0) { | ||
| ArgMaxPair wbest; | ||
| if (lane < warps_per_block) { | ||
| wbest.v = s_val[lane]; | ||
| wbest.i = s_idx[lane]; | ||
| } else { | ||
| wbest.v = -FLT_MAX; | ||
| wbest.i = -1; | ||
| } | ||
|
|
||
| wbest = warp_argmax_xor(wbest); | ||
|
|
||
| if (lane == 0) { | ||
| out_token[row] = wbest.i; | ||
| if (out_maxlogit) | ||
| out_maxlogit[row] = wbest.v; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| inline void launch_argmax_vocab_rows( | ||
| const void* logits, | ||
| ::executorch::aten::ScalarType scalar_type, | ||
| int rows, | ||
| int vocab, | ||
| int* out_token, | ||
| float* out_maxlogit, | ||
| cudaStream_t stream, | ||
| int threads = 256) { | ||
| dim3 block(threads); | ||
| dim3 grid(rows); | ||
|
|
||
| switch (scalar_type) { | ||
| case ::executorch::aten::ScalarType::Float: | ||
| argmax_vocab_rows_kernel<float><<<grid, block, 0, stream>>>( | ||
| (const float*)logits, rows, vocab, out_token, out_maxlogit); | ||
| break; | ||
| case ::executorch::aten::ScalarType::Half: | ||
| argmax_vocab_rows_kernel<half><<<grid, block, 0, stream>>>( | ||
| (const half*)logits, rows, vocab, out_token, out_maxlogit); | ||
| break; | ||
| case ::executorch::aten::ScalarType::BFloat16: | ||
| argmax_vocab_rows_kernel<nv_bfloat16><<<grid, block, 0, stream>>>( | ||
| (const nv_bfloat16*)logits, rows, vocab, out_token, out_maxlogit); | ||
| break; | ||
| default: | ||
| // Unsupported type, fall back to float | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhapes we need to raise error here to avoid silent error? |
||
| argmax_vocab_rows_kernel<float><<<grid, block, 0, stream>>>( | ||
| (const float*)logits, rows, vocab, out_token, out_maxlogit); | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| } // namespace cuda | ||
| } // namespace llm | ||
| } // namespace extension | ||
| } // namespace executorch | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| cmake_minimum_required(VERSION 3.19) | ||
| project(llm_sampler_cuda_tests LANGUAGES CXX CUDA) | ||
|
|
||
| set(CMAKE_CXX_STANDARD 17) | ||
| set(CMAKE_CXX_STANDARD_REQUIRED ON) | ||
| set(CMAKE_CUDA_STANDARD 17) | ||
| set(CMAKE_CUDA_STANDARD_REQUIRED ON) | ||
|
|
||
| # Find required packages | ||
| find_package(CUDAToolkit REQUIRED) | ||
|
|
||
| # Fetch GoogleTest | ||
| include(FetchContent) | ||
| FetchContent_Declare( | ||
| googletest | ||
| GIT_REPOSITORY https://github.com/google/googletest.git | ||
| GIT_TAG v1.14.0 | ||
| ) | ||
| # For Windows: Prevent overriding the parent project's compiler/linker settings | ||
| set(gtest_force_shared_crt | ||
| ON | ||
| CACHE BOOL "" FORCE | ||
| ) | ||
| FetchContent_MakeAvailable(googletest) | ||
|
|
||
| # Get EXECUTORCH_ROOT | ||
| if(NOT EXECUTORCH_ROOT) | ||
| set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) | ||
| endif() | ||
|
|
||
| # Find installed ExecuTorch | ||
| find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX}) | ||
|
|
||
| # List of CUDA test files | ||
| set(LLM_SAMPLER_CUDA_TESTS test_argmax) | ||
|
|
||
| enable_testing() | ||
|
|
||
| foreach(test_name ${LLM_SAMPLER_CUDA_TESTS}) | ||
| add_executable(${test_name} ${test_name}.cu) | ||
|
|
||
| target_include_directories( | ||
| ${test_name} PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT} | ||
| ${CUDAToolkit_INCLUDE_DIRS} | ||
| ) | ||
|
|
||
| target_link_libraries( | ||
| ${test_name} | ||
| PRIVATE GTest::gtest | ||
| GTest::gtest_main | ||
| executorch_core | ||
| CUDA::cudart | ||
| ) | ||
|
|
||
| add_test(NAME ${test_name} COMMAND ${test_name}) | ||
| endforeach() | ||
|
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider using nested namespace to follow c++ 17 standard