
Add hipdnn convolution support #3049

Open
zjgarvey wants to merge 14 commits into ROCm:hipdnn_develop from zjgarvey:hipdnn_convolution_2

Conversation


@zjgarvey zjgarvey commented Mar 6, 2026

Implements forward and backward convolution (2D/3D) through the hipDNN frontend graph API, providing an alternative to the MIOpen backend on ROCm.

  • Shared utilities: Extract createTensorAttributes() into ATen/hipdnn/Utils.h for reuse across hipDNN ops (BatchNorm, Conv)
  • Graph-cached convolution: Forward (fprop), backward-data (dgrad), and backward-weight (wgrad) via hipDNN frontend graphs with a thread-local LRU cache (ParamsLRUCache<K,V>) to amortize graph->build() cost
  • Dispatch integration: New ConvBackend::Hipdnn and ConvBackend::HipdnnTranspose variants wired through backend selection, memory format selection, forward/backward switches, and Python enum exposure. hipDNN takes priority over MIOpen when torch.backends.hipdnn.enabled is True
  • Bias fusion: Forward conv fuses bias via a pointwise ADD node in the graph, avoiding a separate output.add_() call
  • Transposed convolution: Implemented as dgrad (forward) / fprop (backward-input) / wgrad (backward-weight), with bias applied separately (for now) since dgrad + pointwise graphs aren't supported by any hipDNN backends currently.
  • Grouped/depthwise support: hipDNN infers group count from tensor dimensions; explicit output dims are set on dgrad/wgrad graphs so grouping is resolved correctly
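The graph-cached path above hinges on the thread-local `ParamsLRUCache<K,V>`. As a rough sketch of how such a cache can be structured (the class name comes from the PR; the body below is an illustrative guess, not the actual implementation — it also folds in the `cache_limit == 0` means "unlimited" semantics fixed later in this PR):

```cpp
#include <cassert>
#include <cstddef>
#include <list>
#include <unordered_map>

// Illustrative LRU cache sketch, not the PR's ParamsLRUCache implementation.
// limit == 0 means "unlimited": entries are stored but never tracked or
// evicted, so no recency bookkeeping is paid for.
template <typename K, typename V>
class ParamsLRUCacheSketch {
 public:
  explicit ParamsLRUCacheSketch(std::size_t limit) : limit_(limit) {}

  V* find(const K& key) {
    auto it = map_.find(key);
    if (it == map_.end()) return nullptr;
    if (limit_ != 0) {
      // Move the key to the front of the recency list (most recently used).
      order_.splice(order_.begin(), order_, it->second.order_it);
    }
    return &it->second.value;
  }

  void insert(const K& key, V value) {
    if (auto* existing = find(key)) { *existing = std::move(value); return; }
    if (limit_ != 0 && map_.size() >= limit_) {
      // Evict the least-recently-used entry (back of the list).
      map_.erase(order_.back());
      order_.pop_back();
    }
    Entry e{std::move(value), {}};
    if (limit_ != 0) {
      order_.push_front(key);
      e.order_it = order_.begin();
    }
    map_.emplace(key, std::move(e));
  }

  std::size_t size() const { return map_.size(); }

 private:
  struct Entry {
    V value;
    typename std::list<K>::iterator order_it;
  };
  std::size_t limit_;
  std::list<K> order_;                  // front = most recently used
  std::unordered_map<K, Entry> map_;
};
```

Declaring one instance `thread_local` gives each thread its own cache with no locking, which is why the PR could drop `#include <mutex>`.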

@zjgarvey zjgarvey changed the title Hipdnn convolution 2 [wip] add hipdnn convolution support Mar 6, 2026

rocm-repo-management-api bot commented Mar 6, 2026

Jenkins build for 2cea5f50e97a741b501c0fce1667b713966152db commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results


rocm-repo-management-api bot commented Mar 9, 2026

Jenkins build for 168eeb1b729889b93ad504773a9b8c3c5be09592 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results


@zjgarvey zjgarvey left a comment


Self-review


rocm-repo-management-api bot commented Mar 11, 2026

Jenkins build for b0ca170b13923c1d321e2aed5f3b2daef7d5dd09 commit finished as NOT_BUILT
Links: Pipeline Overview / Build artifacts / Test Results

@zjgarvey zjgarvey changed the title [wip] add hipdnn convolution support Add hipdnn convolution support Mar 11, 2026
@zjgarvey zjgarvey marked this pull request as ready for review March 11, 2026 14:55

rocm-repo-management-api bot commented Mar 11, 2026

Jenkins build for b0ca170b13923c1d321e2aed5f3b2daef7d5dd09 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

Detected error during Pytorch building:

[5315/8175] Building CXX object third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/jit_avx512_core_f16_dw_conv_kernel.cpp.o
[5316/8175] Building CXX object third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern_autogen.cpp.o
[5317/8175] Building CXX object third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/gemm/s8x8s32/jit_avx_u8_copy_bn_kern_autogen.cpp.o
[5318/8175] Building CXX object third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/jit_generator.cpp.o
[5319/8175] Building CXX object third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/jit_brgemm_transpose_utils.cpp.o
FAILED: third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/jit_brgemm_transpose_utils.cpp.o 
/opt/cache/bin/sccache /opt/cache/bin/c++ -DDNNL_ENABLE_CPU_ISA_HINTS -DDNNL_ENABLE_ITT_TASKS -DDNNL_ENABLE_MAX_CPU_ISA -DDNNL_X64=1 -DIDEEP_USE_MKL -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DROCM_VERSION=70200 -DTORCH_HIP_VERSION=702 -DUSE_LAYERNORM_FAST_RECIPROCAL -D__STDC_CONSTANT_MACROS -D__STDC_LIMIT_MACROS -I/var/lib/jenkins/pytorch/build/third_party/ideep/mkl-dnn/include -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/include -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/third_party -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/src -isystem /opt/rocm-7.2.0/include -isystem /var/lib/jenkins/pytorch/build/third_party/gloo -isystem /var/lib/jenkins/pytorch/cmake/../third_party/gloo -isystem /var/lib/jenkins/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /var/lib/jenkins/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /var/lib/jenkins/pytorch/cmake/../third_party/googletest/googletest/include -isystem /var/lib/jenkins/pytorch/third_party/protobuf/src -isystem /opt/conda/envs/py_3.12/include -isystem /var/lib/jenkins/pytorch/third_party/XNNPACK/include -isystem /var/lib/jenkins/pytorch/third_party/ittapi/include -isystem /var/lib/jenkins/pytorch/cmake/../third_party/eigen -isystem /opt/rocm/include -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -fopenmp -fvisibility-inlines-hidden  -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -Wall -Wno-unknown-pragmas -Wundef -fvisibility=internal   -fPIC -Wformat -Wformat-security -D_FORTIFY_SOURCE=2 -fstack-protector-strong -fcf-protection=full  -Wmissing-field-initializers  -Wno-strict-overflow -Wno-maybe-uninitialized -Wno-stringop-overflow -Wno-array-bounds  -O3 -DNDEBUG -DNDEBUG -std=c++20 -fPIC -DMKL_HAS_SBGEMM -DMKL_HAS_SHGEMM -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -MD -MT third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/jit_brgemm_transpose_utils.cpp.o -MF 
third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/jit_brgemm_transpose_utils.cpp.o.d -o third_party/ideep/mkl-dnn/src/cpu/x64/CMakeFiles/dnnl_cpu_x64.dir/jit_brgemm_transpose_utils.cpp.o -c /var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/src/cpu/x64/jit_brgemm_transpose_utils.cpp
during RTL pass: cse_local
/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/src/cpu/x64/jit_brgemm_transpose_utils.cpp: In member function ‘void dnnl::impl::cpu::x64::jit_brgemm_copy_to_coarse_t::copy_row_blks(int)’:
/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/src/cpu/x64/jit_brgemm_transpose_utils.cpp:1146:1: internal compiler error: Segmentation fault
 1146 | }

// ---------------------------------------------------------------------------
// Cache key: captures everything that determines graph topology
// ---------------------------------------------------------------------------
constexpr int hipdnn_max_dim = 3;


MIOPEN_DIM_MAX = 5 is already defined and used:

static constexpr int MIOPEN_DIM_MAX = 5;

constexpr int MIOPEN_DIM_MAX = 5;

constexpr size_t MIOPEN_DIM_MAX = 5;



I'd like to leave this separate, since I don't want to use any of these named constants from other source files.

E.g., hipdnn might have other backends which support more dim combinations than miopen.

@zjgarvey

I think I need a rebase, one sec.

@zjgarvey zjgarvey force-pushed the hipdnn_convolution_2 branch from b0ca170 to 184c4c3 Compare March 11, 2026 21:23

rocm-repo-management-api bot commented Mar 11, 2026

Jenkins build for 184c4c3223e16aa7318a8d4f9d82bcad635636c0 commit finished as SUCCESS
Links: Pipeline Overview / Build artifacts


rocm-repo-management-api bot commented Mar 12, 2026

Jenkins build for a2cb42c0f7117d294a9cbd6ffad6ad27341fc523 commit finished as NOT_BUILT
Links: Pipeline Overview / Build artifacts / Test Results


rocm-repo-management-api bot commented Mar 12, 2026

Jenkins build for a2cb42c0f7117d294a9cbd6ffad6ad27341fc523 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results


@mousdahl-amd mousdahl-amd left a comment


Nothing major jumped out at me.

if (!at::globalContext().userEnabledHipdnn()) return false;
if (!detail::getCUDAHooks().compiledWithHipDNN()) return false;
if (!input.is_cuda()) return false;
auto dtype = input.scalar_type();


These datatype / dimension checks are interesting to me. I understand wanting a fast path to skip if hipDNN doesn't support certain datatypes / tensor dimensions, but this could very easily change going forward. I wonder if there's a way we can check in with hipDNN to see if it's got applicable engines instead. I just want to avoid a maintenance headache if possible.

The problem with that approach is that it may be heavier-weight than you want.

zjgarvey and others added 9 commits March 17, 2026 10:34
Implement hipDNN-based forward and backward convolution (2D and 3D)
with a thread-local LRU graph cache to amortize the expensive
graph->build() cost. Supports contiguous and channels-last memory
formats, grouped/depthwise configurations, and transposed convolution.

The graph cache follows the cuDNN v8 pattern from Conv_v8.cpp with
configurable size via TORCH_HIPDNN_CONV_LRU_CACHE_LIMIT env var.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
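On the channels-last support mentioned above: PyTorch's `at::MemoryFormat::ChannelsLast` keeps NCHW dimension order but NHWC memory order, so the layout is conveyed entirely through strides. The helper below is illustrative (not from the PR) and shows the stride pattern a graph builder can hand to hipDNN so it handles the layout natively instead of forcing a contiguous NCHW copy:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Illustrative helper: given NCHW dimensions, compute the strides of a
// channels-last (NHWC) layout, reported per NCHW dim index as PyTorch does.
std::array<int64_t, 4> channels_last_strides(
    const std::array<int64_t, 4>& nchw) {
  const int64_t c = nchw[1], h = nchw[2], w = nchw[3];
  // Memory order is N, H, W, C; C is the fastest-varying dimension.
  return {h * w * c,  // stride of N
          1,          // stride of C
          w * c,      // stride of H
          c};         // stride of W
}
```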
Add Hipdnn and HipdnnTranspose to ConvBackend enum and wire them
through the full dispatch path: backend selection (use_hipdnn check
inserted before use_miopen), memory format selection, forward switch,
backward switch, and Python enum exposure.

hipDNN takes priority over MIOpen when torch.backends.hipdnn.enabled
is True. The hipdnn_conv_suggest_memory_format() supports NHWC/NDHWC
unconditionally since hipDNN handles strides natively.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The functions used SymIntArrayRef/SymInt parameter types but the
dispatch registration in native_functions.yaml maps to the non-symint
overload, which expects IntArrayRef/int64_t. This caused undefined
reference linker errors in libtorch_hip.so.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Fix cache limit env var: replace check_env() (returns bool) with
  get_env() + stoi() so custom limits are actually parsed
- Remove cudnn_enabled gate from use_hipdnn(); userEnabledHipdnn()
  already handles the enable/disable check independently
- Fuse bias into forward conv graph via pointwise(ADD), eliminating
  the separate output.add_(reshape_bias(...)) call
- Remove unused #include <mutex> (cache is thread-local)
- Unify namespace style to `namespace at::native`
- Add clarifying comment on implicit group count via tensor shapes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
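The env-var fix above comes down to reading the value rather than its presence: `check_env()` only reports whether the variable is set, so a custom limit was never parsed. A standalone stand-in for the corrected logic (the PR uses c10's `get_env()`; the function below uses `std::getenv` so it is self-contained, and the fallback behavior on a missing or non-numeric value is my assumption):

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

// Illustrative stand-in for parsing TORCH_HIPDNN_CONV_LRU_CACHE_LIMIT:
// read the value, parse with std::stoi, fall back on absence or garbage.
int parse_cache_limit(const char* name, int fallback) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) return fallback;
  try {
    return std::stoi(raw);
  } catch (const std::exception&) {
    return fallback;  // non-numeric value: keep the default
  }
}
```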
Add test_hipdnn_conv.py with forward+backward tests for hipDNN conv.
Working: fp32 basic conv, dilation. Xfail: bias (plugin needs conv+bias+activ
3-node graph), bf16, grouped/strided backward, transposed conv.
Skip: depthwise (GPU fault).

Fix buildConvFpropGraph to set intermediate_data_type on graph and
compute_data_type on pointwise attributes, matching the hipDNN sample
pattern for fused conv+bias. Bias fusion still blocked by MIOpen
legacy plugin only supporting 3-node (conv+bias+activ) graphs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
hipDNN infers group count from tensor dimensions. For wgrad, both graph
inputs (dy and x) have full channel counts, so without explicitly
setting the output weight shape, hipDNN cannot determine the group
count and defaults to groups=1. This causes out-of-bounds GPU memory
access for grouped/depthwise convolutions.

Set output dims on both dgrad and wgrad graphs using the known
input_size and weight_size respectively.
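The ambiguity described above is worth making concrete. Both wgrad inputs carry full channel counts (dy is [N, C_out, ...] and x is [N, C_in, ...]), while the group count appears only in the weight shape, which for a standard grouped conv is [C_out, C_in / groups, kH, kW]. So for any `groups` dividing C_in the inputs look identical, and defaulting to groups=1 over-sizes the weight. The helper is hypothetical, just to show the shape arithmetic:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical helper: weight shape of a 2D grouped convolution.
// The group count is visible only here, not in the dy/x input shapes.
std::array<int64_t, 4> grouped_weight_shape(
    int64_t c_out, int64_t c_in, int64_t groups, int64_t kh, int64_t kw) {
  return {c_out, c_in / groups, kh, kw};
}
```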

Authored with Claude.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…loat16 numerics.

Signed-off-by: zjgarvey <zjgarvey@gmail.com>
Replace the custom HipdnnGraphCache with a shared ParamsLRUCache<K,V>
template, remove stored UIDs from the cached graph struct in favor of
enum constants with semantic aliases, and move contiguous() calls into
the hipDNN entry points so the dispatch site doesn't need them.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Include the graph output dimensions in HipdnnConvParams so that dgrad
graphs built for different output_padding values (e.g. transposed conv
with output_padding=1 vs 0) are cached separately.
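Why output dims must be part of the key: per PyTorch's documented ConvTranspose2d formula, each transposed-conv output dimension is `out = (in - 1)*stride - 2*pad + dilation*(k - 1) + output_padding + 1`, so `output_padding` changes only the output shape, not the input dims, weight dims, or conv attributes that would otherwise populate the key. A small illustrative helper:

```cpp
#include <cassert>
#include <cstdint>

// Transposed-convolution output size per dimension (PyTorch's documented
// ConvTranspose2d formula). output_padding affects only this value, so two
// graphs differing only in output_padding would collide in a cache key that
// omits output dims.
int64_t conv_transpose_out_dim(int64_t in, int64_t k, int64_t stride,
                               int64_t pad, int64_t dilation,
                               int64_t output_padding) {
  return (in - 1) * stride - 2 * pad + dilation * (k - 1) + output_padding + 1;
}
```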

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
zjgarvey and others added 4 commits March 17, 2026 10:34
Wire bias through buildConvDgradGraph and runHipdnnConvDgrad so
dgrad+pointwise(ADD) fusion is ready when hipDNN backend plugins
support it. For now the transposed conv entry point still applies
bias separately since no plugin handles this pattern yet.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Fix cache_limit == 0 to mean "unlimited" (no eviction, no LRU tracking)
instead of the previous behavior where entries were still added to
cache_order unnecessarily. Add TORCH_INTERNAL_ASSERT on eviction erase
to catch cache corruption early.

Plumb benchmark and deterministic flags through to the cache key so that
different flag combinations produce separate cache entries, preparing for
when HipDNN supports algorithm/engine selection based on these flags.

Authored with Claude.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
deterministic=true is a correctness contract that hipDNN cannot honor
yet (no engine-level determinism filtering), so raise an error rather
than silently producing non-deterministic results.

benchmark=true is a performance hint (algorithm search), safe to ignore
but users should know it has no effect — emit TORCH_WARN_ONCE.

Also removes both flags from the cache key since they do not affect
graph construction.
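The policy above can be sketched as follows. This is a standalone stand-in, not the PR's code: the real implementation would use PyTorch's `TORCH_CHECK` and `TORCH_WARN_ONCE` macros, and the struct and counter below are hypothetical:

```cpp
#include <cassert>
#include <stdexcept>

// Illustrative flag policy: deterministic is a correctness contract, so fail
// loudly rather than silently return non-deterministic results; benchmark is
// only a performance hint, so warn once that it is ignored and proceed.
struct HipdnnFlagPolicy {
  bool warned_benchmark = false;
  int warnings_emitted = 0;

  void check_flags(bool deterministic, bool benchmark) {
    if (deterministic) {
      // TORCH_CHECK analogue: refuse a contract we cannot honor.
      throw std::runtime_error(
          "hipDNN convolution does not support deterministic mode yet");
    }
    if (benchmark && !warned_benchmark) {
      warned_benchmark = true;  // TORCH_WARN_ONCE analogue
      ++warnings_emitted;
    }
  }
};
```

Since neither flag changes graph topology, dropping both from the cache key avoids building duplicate graphs for identical convolutions.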

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of applying bias separately after the dgrad call, pass it
through to runHipdnnConvDgrad so it can be fused when the backend
supports it.

Signed-off-by: Zach Garvey <zachary.garvey@amd.com>
Assisted-By: Claude Opus 4.6 <noreply@anthropic.com>
@zjgarvey zjgarvey force-pushed the hipdnn_convolution_2 branch from a2cb42c to cd74159 Compare March 17, 2026 17:40

rocm-repo-management-api bot commented Mar 17, 2026

Jenkins build for cd7415905e38d51aec57a82df23904df59fd1bb7 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

Per-backend forward ops (hipdnn_convolution, hipdnn_convolution_transpose)
are invisible to the compiler -- AOT autograd captures aten.convolution as
an opaque node, never per-backend ops. Removing them from
native_functions.yaml and using dispatch stubs (the same mechanism backward
has used since Joel Schlosser's 2021 refactor) eliminates derivatives.yaml
entries, FC allowlist entries, HasDecomp entries, and trace_rules entries.

hipDNN is the first convolution backend to use dispatch stubs for forward,
proving the pattern works for a potential broader cleanup of cuDNN/MIOpen.

Assisted-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: zjgarvey <zjgarvey@gmail.com>

rocm-repo-management-api bot commented Mar 19, 2026

Jenkins build for 5e5cd8a61c1d5ee5381f87ee5a44f69e9e4a6620 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results
