tests/cpp/operator/test_grouped_gemm.cu — 6 changes: 3 additions & 3 deletions

@@ -102,8 +102,8 @@ std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
 }
 
 void run_grouped_gemm_case(const TestParams& params) {
-#if CUBLAS_VERSION < 130100
-  GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is "
+#if CUBLAS_VERSION < 130200
+  GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is "
               << CUBLAS_VERSION << ".";
 #else
   if (getDeviceComputeCapability() < blackwellComputeCapability) {

@@ -267,7 +267,7 @@ void run_grouped_gemm_case(const TestParams& params) {
                 atol,
                 rtol);
   }
-#endif  // CUBLAS_VERSION >= 130100
+#endif  // CUBLAS_VERSION >= 130200
 }
 
 class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};
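The gate above pairs a compile-time check on CUBLAS_VERSION (from the cuBLAS headers the tests were built against) with the run-time checks on transformer_engine::cuda::cublas_version() made in the library changes below, so a mismatch at either stage skips or rejects the grouped GEMM path. As an aside, a minimal sketch of how such a packed version integer decodes, assuming the usual major*10000 + minor*100 + patch encoding that the thresholds in this PR (130200 = 13.2.0, 120205 = 12.2.5) follow:

```cpp
// Sketch only (not part of the PR): decode a packed cuBLAS version integer.
// Assumes the CUBLAS_VER_MAJOR*10000 + CUBLAS_VER_MINOR*100 + CUBLAS_VER_PATCH packing.
#include <cublas_v2.h>  // defines CUBLAS_VERSION at compile time
#include <cstdio>

int main() {
  const int v = CUBLAS_VERSION;  // e.g. 130200 for cuBLAS 13.2.0
  std::printf("compile-time cuBLAS %d.%d.%d\n", v / 10000, (v / 100) % 100, v % 100);
  return 0;
}
```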
transformer_engine/common/gemm/cublaslt_gemm.cu — 33 changes: 18 additions & 15 deletions

@@ -494,9 +494,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #endif  // CUBLAS_VERSION >= 120800
   } else if (mxfp8_gemm) {
 #if CUBLAS_VERSION >= 120800
-    NVTE_CHECK(cuda::cublas_version() >= 120800,
+    NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800,
                "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
-               cuda::cublas_version());
+               transformer_engine::cuda::cublas_version());
 
     // Check that scales are in expected format
     NVTE_CHECK(inputA->with_gemm_swizzled_scales,

@@ -518,7 +518,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,

     // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
     // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
-    if (cuda::cublas_version() <= 120803) {
+    if (transformer_engine::cuda::cublas_version() <= 120803) {
       const int64_t dummy_a_vec_stride = 1;
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
           operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,

@@ -530,9 +530,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #endif  // CUBLAS_VERSION >= 120800
   } else if (use_fp4) {  // NVFP4 GEMM
 #if CUBLAS_VERSION >= 120800
-    NVTE_CHECK(cuda::cublas_version() >= 120800,
+    NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800,
                "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
-               cuda::cublas_version());
+               transformer_engine::cuda::cublas_version());
 
     // Check that scales are in expected format
     NVTE_CHECK(inputA->with_gemm_swizzled_scales,

@@ -567,9 +567,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
              (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
               inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
 #if CUBLAS_VERSION >= 120900
-    NVTE_CHECK(cuda::cublas_version() >= 120900,
+    NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120900,
                "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
-               cuda::cublas_version());
+               transformer_engine::cuda::cublas_version());
 
     // Check that matrix formats are valid
     NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&

@@ -602,7 +602,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   }
 
 #if CUBLAS_VERSION >= 120800
-  if (cuda::cublas_version() >= 120800) {
+  if (transformer_engine::cuda::cublas_version() >= 120800) {
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                      CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                                      &scaling_mode_a, sizeof(scaling_mode_a)));

@@ -619,7 +619,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
 #if CUBLAS_VERSION >= 120800
-    if (cuda::cublas_version() >= 120800) {
+    if (transformer_engine::cuda::cublas_version() >= 120800) {
       // NOTE: In all current cases where FP8 output is supported, the input is
       // scaled identically to the output.
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,

@@ -703,12 +703,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#else
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
NVTE_CHECK(transformer_engine::cuda::cudart_version() >= 12020 &&
transformer_engine::cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000,
transformer_engine::cuda::cudart_version());
NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120205 &&
transformer_engine::cuda::cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cuda::cublas_version());
transformer_engine::cuda::cublas_version());
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
Expand Down Expand Up @@ -934,9 +936,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
transformer_engine::cuda::cudart_version());
NVTE_CHECK(
cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000,
transformer_engine::cuda::cublas_version() >= 120205 &&
transformer_engine::cuda::cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cuda::cublas_version());
transformer_engine::cuda::cublas_version());

const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
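Most of the edits in this file replace unqualified cuda:: calls with the fully qualified transformer_engine::cuda::, presumably so name lookup cannot land on a different cuda namespace that happens to be visible at the call site. A minimal sketch of that kind of ambiguity, using hypothetical namespace names that are not TE's real layout:

```cpp
// Illustrative only: the nested namespaces and the stub function are hypothetical.
#include <cstdio>

namespace transformer_engine {
namespace cuda {
int cublas_version() { return 130200; }  // stands in for the real run-time query
}  // namespace cuda

namespace some_feature {
namespace cuda {}  // an unrelated nested "cuda" namespace

void check() {
  // Unqualified cuda::cublas_version() would resolve "cuda" to some_feature::cuda
  // here and fail to compile; the fully qualified name is unambiguous.
  std::printf("%d\n", transformer_engine::cuda::cublas_version());
}
}  // namespace some_feature
}  // namespace transformer_engine

int main() {
  transformer_engine::some_feature::check();
  return 0;
}
```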
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — 20 changes: 10 additions & 10 deletions

@@ -26,7 +26,7 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) {

 }  // namespace
 
-#if CUBLAS_VERSION >= 130100
+#if CUBLAS_VERSION >= 130200
 
 namespace {
 
@@ -543,13 +543,13 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
   NVTE_API_CALL(nvte_grouped_gemm);
   using namespace transformer_engine;
 
-  // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.1+
-  const int current_device = cuda::current_device();
-  NVTE_CHECK(cuda::sm_arch(current_device) >= 100,
+  // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+
+  const int current_device = transformer_engine::cuda::current_device();
+  NVTE_CHECK(transformer_engine::cuda::sm_arch(current_device) >= 100,
              "nvte_grouped_gemm requires Blackwell (SM100) or newer architecture.");
-  NVTE_CHECK(cuda::cublas_version() >= 130100,
-             "nvte_grouped_gemm requires cuBLAS 13.1+, but run-time cuBLAS version is ",
-             cuda::cublas_version());
+  NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 130200,
+             "nvte_grouped_gemm requires cuBLAS 13.2+, but run-time cuBLAS version is ",
+             transformer_engine::cuda::cublas_version());
 
   // Convert to internal types
   const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A);

Review comment (Collaborator), on the added NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 130200, ...) line: shouldn't there be more than one place to add this transformer_engine::?

@@ -631,15 +631,15 @@
                                       kGroupedGemmCublasWorkspaceSize, stream));
 }
 
-#else  // CUBLAS_VERSION < 130100
+#else  // CUBLAS_VERSION < 130200
 
 void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
                        const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
                        const NVTETensor beta, NVTETensor workspace_setup,
                        NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
                        cudaStream_t stream) {
-  NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ",
+  NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ",
              CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");

Review comment (Contributor), on the line above: the error message says "upgrade to CUDA 13.1 or newer" but should say "CUDA 13.2 or newer" to match the cuBLAS 13.2+ requirement.

Suggested change:
-             CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
+             CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer.");

 }
 
-#endif  // CUBLAS_VERSION >= 130100
+#endif  // CUBLAS_VERSION >= 130200
transformer_engine/common/include/transformer_engine/gemm.h — 4 changes: 2 additions & 2 deletions

@@ -299,7 +299,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
 /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
 /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C
  *
- *  \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture.
+ *  \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture.
  *        Will error at runtime if compiled with an older cuBLAS version or run on
  *        a pre-Blackwell GPU.
  *

@@ -322,7 +322,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
  *  \param[in]  stream  CUDA stream for the operation.
  *
  *  Requirements:
- *   - cuBLAS 13.1+ (CUDA 13.1+)
+ *   - cuBLAS 13.2+ (CUDA 13.1+)
  *   - Blackwell (SM100) or newer GPU architecture
  *   - A, B, C (if provided), D must have the same num_tensors
  *   - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
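To make the documented contract concrete, here is a minimal call sketch against the nvte_grouped_gemm signature shown in the cublaslt_grouped_gemm.cu fallback above. It assumes the grouped tensors, alpha/beta tensors, workspaces, and config handle were already created through the TE C API; the wrapper itself is illustrative and not part of this PR:

```cpp
// Illustrative sketch only: callers are assumed to have built all handles already.
#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

// For each group i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
void run_grouped_matmul(NVTEGroupedTensor A, int transa, NVTEGroupedTensor B, int transb,
                        NVTEGroupedTensor C, NVTEGroupedTensor D,
                        NVTETensor alpha, NVTETensor beta,
                        NVTETensor workspace_setup, NVTETensor workspace_cublas,
                        NVTEGroupedMatmulConfig config, cudaStream_t stream) {
  // Requires cuBLAS 13.2+ and a Blackwell (SM100) or newer GPU; errors out otherwise.
  nvte_grouped_gemm(A, transa, B, transb, C, D, alpha, beta,
                    workspace_setup, workspace_cublas, config, stream);
}
```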