diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index a18f108e473..94c339f643e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -13,105 +13,12 @@ #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_batched_gemm_wmma_cshuffle_v3( - typename GridwiseGemm::Argument karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - // The normal approach to batching would be to increase the grid size by just stretching out - // the grid Z dimension (which is the outermost dimension), but this depends on lower level - // functions not directly using the Z dimension for other calculations. As it turns out, k - // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now - // we will use the grid Y dimension for batching. This may be a bit fragile. 
- const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - // shift A matrices pointer for splitk - typename GridwiseGemm::AsGridPointer p_as_grid_shift; - static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { - using ADataType_ = - remove_cvref_t>; - p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + - splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; - }); - - // shift B matrices pointer for splitk - typename GridwiseGemm::BsGridPointer p_bs_grid_shift; - static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { - using BDataType_ = - remove_cvref_t>; - p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + - splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; - }); - - auto epilogue_args = EpilogueType{}; - - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; - ignore = compute_ptr_offset_of_batch; -#endif -} - /// @brief \"Universal\" Batched GEMM operation without SplitK support. /// /// @par Overview @@ -271,36 +178,6 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - private: - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; - }; - // GridwiseGemm using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< ALayout, @@ -354,330 +231,40 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm; // PermuteB not supported by DeviceBatchedGemm base class. 
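For reference, the ComputePtrOffsetOfStridedBatch helper removed here (and its B-scale counterpart below) now lives in device_batched_gemm_wmma_cshuffle_v3_common.hpp. A minimal standalone sketch of the offset arithmetic it encapsulates (hypothetical names, not the CK types): each grid pointer is advanced by g_idx * BatchStride, where g_idx is the batch index the kernel reads from blockIdx.y and the product is formed in a 64-bit index type.

// Minimal sketch of the per-batch pointer offset pattern (illustrative only; the
// real helper is ComputePtrOffsetOfStridedBatch in the new common header).
#include <cassert>
#include <cstdint>

struct StridedBatchOffsetsSketch
{
    std::int64_t batch_stride_a;
    std::int64_t batch_stride_b;
    std::int64_t batch_stride_c;

    // g_idx is the batch index; the kernel derives it from blockIdx.y.
    constexpr std::int64_t a_offset(int g_idx) const { return g_idx * batch_stride_a; }
    constexpr std::int64_t b_offset(int g_idx) const { return g_idx * batch_stride_b; }
    constexpr std::int64_t c_offset(int g_idx) const { return g_idx * batch_stride_c; }
};

int main()
{
    // Hypothetical densely packed 64x64 matrices: batch stride == one matrix.
    constexpr StridedBatchOffsetsSketch offsets{64 * 64, 64 * 64, 64 * 64};
    static_assert(offsets.a_offset(0) == 0);
    static_assert(offsets.b_offset(1) == 64 * 64);
    assert(offsets.c_offset(3) == 3 * 64 * 64);
    return 0;
}

The 64-bit result matters because g_idx * BatchStride can exceed the 32-bit index range for large batches, which is why the original helper casts the stride to long_index_t before multiplying.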
- // Argument - struct Argument : public GridwiseGemm::Argument - { - __host__ Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - index_t StrideC_, - index_t BatchStrideA_, - index_t BatchStrideB_, - index_t BatchStrideC_, - index_t Batch_, - index_t k_batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation cde_element_op_, - bool is_reduce_ = false) - : GridwiseGemm::Argument(std::array{p_a_grid_}, - std::array{p_b_grid_}, - std::array{}, // p_ds_grid_ - p_c_grid_, - M_, - N_, - K_, - std::array{StrideA_}, - std::array{StrideB_}, - std::array{}, // StrideDs_ - StrideC_, - k_batch_, - a_element_op_, - b_element_op_, - cde_element_op_, - is_reduce_), - Batch(Batch_), - compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_} - { - } - - index_t Batch; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; - }; - - /// @brief Helper structure responsible for kernel invocation. - /// - /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU - /// kernel function. It usually determines the launched grid size prepares kernel - /// arguments as well as perform specific kernel configuration selection based on - /// runtime arguments. - /// - /// @note If appropriately configured it may measure kernel execution time. - /// - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - // The normal approach to batching would be to increase the grid size by just stretching - // out the grid Z dimension (which is the outermost dimension), but this depends on - // lower level functions not directly using the Z dimension for other calculations. As - // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. - // Therefore, for now we will use the grid Y dimension for batching. This may be a bit - // fragile. - gdy *= arg.Batch; - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - // Packed sizes are 1 for all implemented data types but we include it anyway - // for future compatibility. 
- // Note: the grid descriptors and size_a / size_b do *not* take batching into - // account, so we have to manually multiply overall buffer sizes for rotating - // memory by batch. - std::array size_as_buffers; - size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch; - - std::array size_bs_buffers; - size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch; - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - Tuple<>> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - std::array{}); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_, - arg_.compute_ptr_offset_of_batch); - } - else - { - auto clear_workspace = [&]() { - // clear c mem - if(arg.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg.p_e_grid, - 0, - arg.Batch * arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - clear_workspace, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg, - arg.compute_ptr_offset_of_batch); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - else - { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } + using DeviceGemmCommon = DeviceBatchedGemm_Wmma_CShuffleV3_Common< + GridwiseGemm, + Tuple, + Tuple, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // IsBScaled + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } + // Argument + using Argument = typename DeviceGemmCommon::Argument; - return GridwiseGemm::CheckValidity(arg); - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - // TODO: This is not part of the DeviceBatchedGemm base class but it was part of - // DeviceBatchedGemmV2. Remove? 
- // index_t GetKPerBlock() override { return KPerBlock; } - // bool GetPermuteA() override { return PermuteA; } - // bool GetPermuteB() override { return PermuteB; } - static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, @@ -762,48 +349,15 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm BlkGemmPipelineSchedulerToString{ - {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, - {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; - - std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, - {BlockGemmPipelineVersion::v2, "v2"}, - {BlockGemmPipelineVersion::v3, "v3"}, - {BlockGemmPipelineVersion::v4, "v4"}, - {BlockGemmPipelineVersion::v5, "v5"}}; - - // clang-format off - str << "DeviceBatchedGemm_Wmma_CShuffleV3" - << "<" - << getGemmSpecializationString(GemmSpec) << ", " - << std::string(ALayout::name)[0] - << std::string(BLayout::name)[0] - << std::string(CLayout::name)[0] - << ">" - << " BlkSize: " - << BlockSize << ", " - << "BlkTile: " - << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " - << "WaveTile: " - << MPerWmma << "x"<(); } REGISTER_EXTRA_PRINTING_METHODS }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index b88f071a962..d682ca4ffa4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -13,109 +13,12 @@ #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_batched_gemm_b_scale_wmma_cshuffle_v3( - typename GridwiseGemm::Argument karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - // The normal approach to batching would be to increase the grid size by just stretching out - // the grid Z dimension (which is the outermost dimension), but this depends on lower level - // functions not directly using the Z dimension for other calculations. As it turns out, k - // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now - // we will use the grid Y dimension for batching. This may be a bit fragile. 
- __shared__ char p_shared[LDS_size]; - - const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - const long_index_t b_scale_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - // shift A matrices pointer for splitk - typename GridwiseGemm::AsGridPointer p_as_grid_shift; - static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { - using ADataType_ = - remove_cvref_t>; - p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + - splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; - }); - - // shift B matrices pointer for splitk - typename GridwiseGemm::BsGridPointer p_bs_grid_shift; - static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { - using BDataType_ = - remove_cvref_t>; - p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + - splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; - }); - - auto epilogue_args = EpilogueType{}; - - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - karg.p_a_scale_grid, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; - ignore = compute_ptr_offset_of_batch; -#endif -} - /// @brief \"Universal\" Batched GEMM operation without SplitK support. /// /// @par Overview @@ -282,45 +185,6 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale static_assert(PermuteB == false, "Permute B functionality not supported by DeviceBatchedGemm operations.\n"); - struct ComputePtrOffsetOfStridedBatch - { - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - index_t BatchStrideC, - index_t BatchStrideScaleB) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideC_(BatchStrideC), - BatchStrideScaleB_(BatchStrideScaleB) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_) / GridwiseGemm::BPackedSize; - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - __host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideScaleB_); - } - - private: - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; - index_t BatchStrideScaleB_; - }; - // GridwiseGemm using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< ALayout, @@ -379,328 +243,40 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale PermuteA, // PermuteA not supported by DeviceBatchedGemm base class. PermuteB>; // PermuteB not supported by DeviceBatchedGemm base class. 
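For reference, the B-scale variant of the removed offset helper differs from the plain one only by the extra BatchStrideScaleB member, the GetScaleBPtrOffset accessor, and the division of the B offset by BPackedSize. The shared header merges both through constructors selected at compile time with std::enable_if on the IsBScaled flag; a minimal standalone sketch of that selection pattern (hypothetical names, not the CK class):

// Sketch of enable_if-selected constructors, the technique used by the common
// ComputePtrOffsetOfStridedBatch and Argument types (names here are illustrative).
#include <type_traits>

template <bool IsBScaled>
struct BatchOffsetsSketch
{
    // Enabled only when the class is instantiated without B scaling.
    template <bool S = IsBScaled, typename = std::enable_if_t<!S>>
    BatchOffsetsSketch(int stride_a, int stride_b, int stride_c)
        : stride_a_(stride_a), stride_b_(stride_b), stride_c_(stride_c), stride_scale_b_(0)
    {
    }

    // Enabled only for the B-scaled instantiation, which carries one extra stride.
    template <bool S = IsBScaled, typename = std::enable_if_t<S>>
    BatchOffsetsSketch(int stride_a, int stride_b, int stride_c, int stride_scale_b)
        : stride_a_(stride_a),
          stride_b_(stride_b),
          stride_c_(stride_c),
          stride_scale_b_(stride_scale_b)
    {
    }

    private:
    int stride_a_;
    int stride_b_;
    int stride_c_;
    int stride_scale_b_;
};

int main()
{
    BatchOffsetsSketch<false> plain{1, 2, 3};    // 3-argument constructor is the only viable one
    BatchOffsetsSketch<true> scaled{1, 2, 3, 4}; // 4-argument constructor is the only viable one
    // BatchOffsetsSketch<false> bad{1, 2, 3, 4}; // would not compile: constructor disabled
    (void)plain;
    (void)scaled;
    return 0;
}

The common Argument in the new header additionally static_asserts on BScaleDataType inside each constructor body, so a call that mixes up the scaled and unscaled forms fails with a clearer diagnostic.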
- // Argument - struct Argument : public GridwiseGemm::Argument - { - __host__ Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - index_t StrideC_, - index_t StrideScaleB_, - index_t BatchStrideA_, - index_t BatchStrideB_, - index_t BatchStrideC_, - index_t BatchStrideScaleB_, - const BScaleDataType* p_b_scale_grid_, - index_t Batch_, - index_t k_batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_, - bool is_reduce_ = false) - : GridwiseGemm::Argument(std::array{p_a_grid_}, - std::array{p_b_grid_}, - std::array{}, // p_ds_grid_ - p_c_grid_, - M_, - N_, - K_, - std::array{StrideA_}, - std::array{StrideB_}, - std::array{}, // StrideDs_ - StrideC_, - 0, // StrideScaleA - StrideScaleB_, - nullptr, - p_b_scale_grid_, - k_batch_, - a_element_op_, - b_element_op_, - c_element_op_, - is_reduce_), - Batch(Batch_), - compute_ptr_offset_of_batch{ - BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_} - { - } - - index_t Batch; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; - }; - - /// @brief Helper structure responsible for kernel invocation. - /// - /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU - /// kernel function. It usually determines the launched grid size prepares kernel - /// arguments as well as perform specific kernel configuration selection based on - /// runtime arguments. - /// - /// @note If appropriately configured it may measure kernel execution time. - /// - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - // The normal approach to batching would be to increase the grid size by just stretching - // out the grid Z dimension (which is the outermost dimension), but this depends on - // lower level functions not directly using the Z dimension for other calculations. As - // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. - // Therefore, for now we will use the grid Y dimension for batching. This may be a bit - // fragile. 
- gdy *= arg.Batch; - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - // Packed sizes are 1 for all implemented data types but we include it anyway - // for future compatibility. - // Note: the grid descriptors and size_a / size_b do *not* take batching into - // account, so we have to manually multiply overall buffer sizes for rotating - // memory by batch. - std::array size_as_buffers; - size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch; - - std::array size_bs_buffers; - size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch; - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - Tuple<>> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - std::array{}); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - ck::utility::flush_icache(); - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_, - arg_.compute_ptr_offset_of_batch); - } - else - { - auto clear_workspace = [&]() { - // clear c mem - if(arg.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg.p_e_grid, - 0, - arg.Batch * arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - clear_workspace, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg, - arg.compute_ptr_offset_of_batch); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - else - { - throw std::runtime_error("Pipeline not implemented"); - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } + using DeviceGemmCommon = DeviceBatchedGemm_Wmma_CShuffleV3_Common< + GridwiseGemm, + Tuple, + Tuple, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true, // IsBScaled + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GridwiseGemm::BPackedSize, + BScaleDataType>; - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } + // Argument + using Argument = typename DeviceGemmCommon::Argument; - return GridwiseGemm::CheckValidity(arg); - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } index_t GetKPerBlock() override { return KPerBlock; } @@ -801,48 +377,15 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale // polymorphic std::string GetTypeString() const override { - auto str = std::stringstream(); - - std::map BlkGemmPipelineSchedulerToString{ - {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, - {BlockGemmPipelineScheduler::Interwave, 
"Interwave"}}; - - std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, - {BlockGemmPipelineVersion::v2, "v2"}, - {BlockGemmPipelineVersion::v3, "v3"}, - {BlockGemmPipelineVersion::v4, "v4"}, - {BlockGemmPipelineVersion::v5, "v5"}}; - - // clang-format off - str << "DeviceBatchedGemm_Wmma_CShuffleV3_BScale" - << "<" - << getGemmSpecializationString(GemmSpec) << ", " - << std::string(ALayout::name)[0] - << std::string(BLayout::name)[0] - << std::string(CLayout::name)[0] - << ">" - << " BlkSize: " - << BlockSize << ", " - << "BlkTile: " - << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " - << "WaveTile: " - << MPerWmma << "x"<(); } REGISTER_EXTRA_PRINTING_METHODS }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 00000000000..59a820861c3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,529 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template > +struct DeviceBatchedGemm_Wmma_CShuffleV3_Common +{ + struct ComputePtrOffsetOfStridedBatch + { + template > + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + template > + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideScaleB_(BatchStrideScaleB) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + static_assert(BPackedSize != 0); + static_assert(IsBScaled || (!IsBScaled && BPackedSize == 1)); + return g_idx * static_cast(BatchStrideB_) / BPackedSize; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + __host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(*BatchStrideScaleB_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + std::optional BatchStrideScaleB_; + }; + + struct Argument : public GridwiseGemm::Argument + { + using ADataType = typename AsDataType::DataType; + using BDataType = typename BsDataType::DataType; + template > + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t Batch_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + 
BElementwiseOperation b_element_op_, + CElementwiseOperation cde_element_op_, + bool is_reduce_ = false) + : GridwiseGemm::Argument(std::array{p_a_grid_}, + std::array{p_b_grid_}, + std::array{}, // p_ds_grid_ + p_c_grid_, + M_, + N_, + K_, + std::array{StrideA_}, + std::array{StrideB_}, + std::array{}, // StrideDs_ + StrideC_, + k_batch_, + a_element_op_, + b_element_op_, + cde_element_op_, + is_reduce_), + Batch(Batch_), + compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_} + { + static_assert(std::is_same_v>); + } + + template > + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t BatchStrideScaleB_, + const BScaleDataType* p_b_scale_grid_, + index_t Batch_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : GridwiseGemm::Argument(std::array{p_a_grid_}, + std::array{p_b_grid_}, + std::array{}, // p_ds_grid_ + p_c_grid_, + M_, + N_, + K_, + std::array{StrideA_}, + std::array{StrideB_}, + std::array{}, // StrideDs_ + StrideC_, + 0, // StrideScaleA + StrideScaleB_, + nullptr, + p_b_scale_grid_, + k_batch_, + a_element_op_, + b_element_op_, + c_element_op_, + is_reduce_), + Batch(Batch_), + compute_ptr_offset_of_batch{ + BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_} + { + static_assert(!std::is_same_v>); + } + + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + }; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + // The normal approach to batching would be to increase the grid size by just stretching + // out the grid Z dimension (which is the outermost dimension), but this depends on + // lower level functions not directly using the Z dimension for other calculations. As + // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. + // Therefore, for now we will use the grid Y dimension for batching. This may be a bit + // fragile. 
+ gdy *= arg.Batch; + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); + + // Packed sizes are 1 for all implemented data types but we include it anyway + // for future compatibility. + // Note: the grid descriptors and size_a / size_b do *not* take batching into + // account, so we have to manually multiply overall buffer sizes for rotating + // memory by batch. + std::array size_as_buffers; + size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * + GridwiseGemm::NumATensor / GridwiseGemm::APackedSize * + arg_.Batch; + + std::array size_bs_buffers; + size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * + GridwiseGemm::NumBTensor / GridwiseGemm::BPackedSize * + arg_.Batch; + + ck::utility:: + RotatingMemWrapperMultiABD> + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + std::array{}); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + ck::utility::flush_icache(); + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg_.p_e_grid, + 0, + arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_, + arg_.compute_ptr_offset_of_batch); + } + else + { + auto clear_workspace = [&]() { + // clear c mem + if(arg.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg.p_e_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + arg.compute_ptr_offset_of_batch); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + using ComputePtrOffsetOfStridedBatch = decltype(arg.compute_ptr_offset_of_batch); + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + } + else + { + throw std::runtime_error("Pipeline not implemented"); + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + template + static std::string GetTypeString() + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + constexpr auto type = []() { + if constexpr(IsBScaled) + { + return "DeviceBatchedGemm_Wmma_CShuffleV3_BScale"; + } + else + { + return "DeviceBatchedGemm_Wmma_CShuffleV3"; + } + }(); + // clang-format off + str << type + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << 
"WaveTile: " + << MPerWmma << "x"< +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + // shift A matrices pointer for splitk + typename GridwiseGemm::AsGridPointer p_as_grid_shift; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = + remove_cvref_t>; + p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; + }); + + // shift B matrices pointer for splitk + typename GridwiseGemm::BsGridPointer p_bs_grid_shift; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = + remove_cvref_t>; + p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; + }); + + auto epilogue_args = EpilogueType{}; + + if constexpr(IsBScaled) + { + const long_index_t b_scale_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); + + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + karg.p_a_scale_grid, + karg.p_b_scale_grid + b_scale_batch_offset + + splitk_batch_offset.scale_b_k_split_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } + else + { + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + template