[Bug] Segfault in tvm.compile on **LLVM (CPU) target** when tir.ptx_ldg32=1: unexpectedly runs tir::transform::InjectPTXLDG32 / PTXRewriter and crashes in BufferStore #18617


Summary

tvm.compile segfaults when compiling a Relax module (imported from torch.export) even with a pure CPU target (llvm -keys=cpu). The crash occurs inside the PTX-specific pass:

  • tvm::tir::transform::InjectPTXLDG32(bool)
  • tvm::tir::PTXRewriter::VisitStmt_(BufferStoreNode const*)
  • tvm::tir::BufferStore::BufferStore(...)

This is unexpected because the target is LLVM/CPU, yet the compilation pipeline still enters a PTX rewriting pass. Removing tir.ptx_ldg32 from PassContext avoids the crash.
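For reference, a sketch of the workaround, reusing ir_mod and target from the repro script below:

# Same PassContext as the repro, minus "tir.ptx_ldg32".
with tvm.transform.PassContext(opt_level=0, disabled_pass=["LoopPartition"]):
    tvm.compile(ir_mod, target=target)  # completes without a segfault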

This suggests either:

  1. tir.ptx_ldg32 enables a PTX-only pass without checking whether the target is CUDA/PTX-capable, or
  2. InjectPTXLDG32 lacks a defensive early-exit / target predicate and can crash on non-PTX code paths (a user-side guard of this kind is sketched after this list).
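
Either way, the fix direction is the same: gate on the target kind. Purely as an illustration (a hypothetical helper, not a TVM API, assuming tir.ptx_ldg32 is only meaningful for CUDA targets), a user-side guard could sanitize the config before entering PassContext:

import tvm

def sanitize_pass_config(target: tvm.target.Target, config: dict) -> dict:
    # Hypothetical helper, not a TVM API: drop PTX-only keys when the
    # target is not CUDA (assumes "tir.ptx_ldg32" is PTX/CUDA-specific).
    ptx_only_keys = {"tir.ptx_ldg32"}
    if target.kind.name != "cuda":
        return {k: v for k, v in config.items() if k not in ptx_only_keys}
    return dict(config)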

Environment

From the repro output:

  • TVM: 0.22.0
  • Commit: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
  • LLVM: 17.0.6
  • Python: 3.10.16 (from stack paths)
  • NumPy: 2.2.6
  • PyTorch: 2.9.0+cu128

Target used in repro:

llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu
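
A quick sanity check (assuming this exact target string) confirms TVM sees it as a CPU-only target:

import tvm

t = tvm.target.Target("llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu")
assert t.kind.name == "llvm"  # LLVM CPU backend, not "cuda"
assert "cpu" in t.keys        # the target keys mark it as CPU-only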

Reproduction Steps

  1. Convert a small PyTorch module to Relax via torch.export.export + tvm.relax.frontend.torch.from_exported_program.
  2. Call tvm.compile under a PassContext with config={"tir.ptx_ldg32": 1}.
  3. Observe segfault during compilation.

Minimal Repro Script

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tvm
from tvm import tir


def print_env_info():
    print("==== Environment Info ====")
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    try:
        li = tvm.support.libinfo()
        print("TVM git commit:", li.get("GIT_COMMIT_HASH", "unknown"))
        print("TVM LLVM version:", li.get("LLVM_VERSION", "unknown"))
    except Exception:
        pass
    print("Python (numpy) version:", np.__version__)
    print("PyTorch version:", torch.__version__)
    print("==========================\n")


class BranchNet(nn.Module):
    def __init__(self, k: int):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, k, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        self.pool = nn.MaxPool2d(2)

        # Feature-map sizes for a 28x28 input: conv1 (kernel k, stride 1),
        # then conv2 (kernel 3, stride 1), then 2x2 max-pool.
        s1 = 28 - k + 1
        s2 = s1 - 2
        sp = s2 // 2
        self.fc = nn.Linear(32 * sp * sp, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        return self.fc(x)


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.b1 = BranchNet(3)
        self.b2 = BranchNet(5)
        self.b3 = BranchNet(7)
        self.out = nn.Linear(30, 10)

    def forward(self, x):
        a = self.b1(x)
        b = self.b2(x)
        c = self.b3(x)
        y = self.out(torch.cat([a, b, c], dim=1))
        return F.log_softmax(y, dim=1)


def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
    mod = mod.to("cpu").eval()
    x = x.to("cpu")
    ep = torch.export.export(mod, (x,))
    from tvm.relax.frontend.torch import from_exported_program
    return from_exported_program(ep)


def main():
    print_env_info()

    target = tvm.target.Target("llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu")
    tir_pipeline = tir.get_default_tir_pipeline(target)  # "explicit_default": the default TIR pipeline for this target
    relax_pipeline = "default"

    x = torch.rand(1, 1, 28, 28, dtype=torch.float32)
    ir_mod = export_to_relax(M(), x)

    pc = {
        "opt_level": 0,
        "disabled_pass": ["LoopPartition"],
        "config": {
            "tir.ptx_ldg32": 1,
        },
    }

    print("[repro] target:", target)
    print("[repro] tir_pipeline: explicit_default")
    print("[repro] compiling with tvm.compile ...")
    with tvm.transform.PassContext(**pc):
        tvm.compile(ir_mod, target=target, relax_pipeline=relax_pipeline, tir_pipeline=tir_pipeline)


if __name__ == "__main__":
    main()

Actual Behavior

Segfault during compilation. The stack trace shows the PTX rewrite pass running even though the target is LLVM:

tvm::tir::BufferStore::BufferStore(...)
tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
...
tvm::tir::transform::InjectPTXLDG32(bool)
Segmentation fault (core dumped)

Expected Behavior

On a CPU/LLVM target:

  1. Setting tir.ptx_ldg32=1 should either be ignored (no-op) or rejected with a clear error message (a rough version of such a check is sketched after this list), and
  2. PTX-specific passes such as InjectPTXLDG32 should not run for non-CUDA targets, and
  3. TVM should never segfault; failures should be surfaced as Python exceptions with diagnostics.
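
As an illustration of the check item 1 asks for (hypothetical, not existing TVM behavior):

import tvm

def check_ptx_config(target: tvm.target.Target, config: dict) -> None:
    # Hypothetical validation, not existing TVM behavior: raise a clear
    # Python exception instead of segfaulting when a PTX-only option is
    # set for a non-CUDA target.
    if config.get("tir.ptx_ldg32") and target.kind.name != "cuda":
        raise ValueError(
            f"tir.ptx_ldg32 requires a CUDA target, got '{target.kind.name}'"
        )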

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage
  • bug
