Description
Summary
`tvm.compile` segfaults when compiling a Relax module (imported from `torch.export`) even with a pure CPU target (`llvm -keys=cpu`). The crash occurs inside the PTX-specific pass:

```
tvm::tir::transform::InjectPTXLDG32(bool)
tvm::tir::PTXRewriter::VisitStmt_(BufferStoreNode const*)
tvm::tir::BufferStore::BufferStore(...)
```

This is unexpected because the target is LLVM/CPU, yet the compilation pipeline still enters a PTX rewriting pass. Removing `tir.ptx_ldg32` from the `PassContext` avoids the crash.
This suggests either:
- `tir.ptx_ldg32` enables a PTX-only pass without checking whether the target is CUDA/PTX-capable, or
- `InjectPTXLDG32` lacks a defensive early-exit / target predicate and can crash on non-PTX code paths (a sketch of such a guard follows this list).
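For illustration, here is a minimal sketch of the kind of target predicate the second point argues for. This is not TVM's actual implementation; `PTX_ONLY_FLAGS` and `sanitize_pass_config` are hypothetical names used only to make the idea concrete:

```python
import tvm

# Hypothetical illustration only -- not TVM's actual implementation.
# The idea: drop PTX-only pass-config flags when the target cannot run PTX.
PTX_ONLY_FLAGS = {"tir.ptx_ldg32"}  # assumed set of PTX-only flags

def sanitize_pass_config(config: dict, target: tvm.target.Target) -> dict:
    """Return config with PTX-only options removed for non-CUDA targets."""
    if target.kind.name == "cuda":
        return config
    return {k: v for k, v in config.items() if k not in PTX_ONLY_FLAGS}

cpu = tvm.target.Target("llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu")
print(sanitize_pass_config({"tir.ptx_ldg32": 1}, cpu))  # {} -- flag dropped
```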
Environment
From the repro output:
- TVM: 0.22.0
- Commit: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
- LLVM: 17.0.6
- Python: 3.10.16 (from stack paths)
- NumPy: 2.2.6
- PyTorch: 2.9.0+cu128

Target used in repro: `llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu`
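As a sanity check (not part of the repro itself), the target can be inspected from Python to confirm it carries no CUDA/PTX capability:

```python
import tvm

target = tvm.target.Target("llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu")
print(target.kind.name)   # "llvm" -- a CPU backend, not CUDA/PTX
print(list(target.keys))  # ["cpu"] -- no "cuda" key present
```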
Reproduction Steps
- Convert a small PyTorch module to Relax via `torch.export.export` + `tvm.relax.frontend.torch.from_exported_program`.
- Call `tvm.compile` under a `PassContext` with `config={"tir.ptx_ldg32": 1}`.
- Observe a segfault during compilation.
Minimal Repro Script
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import tvm
from tvm import tir


def print_env_info():
    print("==== Environment Info ====")
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    try:
        li = tvm.support.libinfo()
        print("TVM git commit:", li.get("GIT_COMMIT_HASH", "unknown"))
        print("TVM LLVM version:", li.get("LLVM_VERSION", "unknown"))
    except Exception:
        pass
    print("Python (numpy) version:", np.__version__)
    print("PyTorch version:", torch.__version__)
    print("==========================\n")


class BranchNet(nn.Module):
    """A small conv branch: two convolutions, a max-pool, and a linear head."""

    def __init__(self, k: int):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, k, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        self.pool = nn.MaxPool2d(2)
        s1 = 28 - k + 1   # spatial size after conv1
        s2 = s1 - 2       # after conv2 (3x3, stride 1)
        sp = s2 // 2      # after 2x2 max-pool
        self.fc = nn.Linear(32 * sp * sp, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        return self.fc(x)


class M(nn.Module):
    """Three BranchNets with different kernel sizes, concatenated."""

    def __init__(self):
        super().__init__()
        self.b1 = BranchNet(3)
        self.b2 = BranchNet(5)
        self.b3 = BranchNet(7)
        self.out = nn.Linear(30, 10)

    def forward(self, x):
        a = self.b1(x)
        b = self.b2(x)
        c = self.b3(x)
        y = self.out(torch.cat([a, b, c], dim=1))
        return F.log_softmax(y, dim=1)


def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
    mod = mod.to("cpu").eval()
    x = x.to("cpu")
    ep = torch.export.export(mod, (x,))
    from tvm.relax.frontend.torch import from_exported_program
    return from_exported_program(ep)


def main():
    print_env_info()
    target = tvm.target.Target("llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu")
    tir_pipeline = tir.get_default_tir_pipeline(target)  # explicit_default
    relax_pipeline = "default"
    x = torch.rand(1, 1, 28, 28, dtype=torch.float32)
    ir_mod = export_to_relax(M(), x)
    pc = {
        "opt_level": 0,
        "disabled_pass": ["LoopPartition"],
        "config": {
            "tir.ptx_ldg32": 1,  # PTX-only flag, despite the LLVM/CPU target
        },
    }
    print("[repro] target:", target)
    print("[repro] tir_pipeline: explicit_default")
    print("[repro] compiling with tvm.compile ...")
    with tvm.transform.PassContext(**pc):
        tvm.compile(ir_mod, target=target, relax_pipeline=relax_pipeline, tir_pipeline=tir_pipeline)


if __name__ == "__main__":
    main()
```

Actual Behavior
Segfault during compilation. The stack trace shows the PTX rewrite pass running even though the target is LLVM:

```
tvm::tir::BufferStore::BufferStore(...)
tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
...
tvm::tir::transform::InjectPTXLDG32(bool)
Segmentation fault (core dumped)
```
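As noted in the summary, removing the flag avoids the crash. Reusing `ir_mod`, `target`, `relax_pipeline`, and `tir_pipeline` from the repro script, a minimal form of the workaround looks like this (a sketch based on the report; the flag is simply omitted from the config):

```python
# Same compilation, but with "tir.ptx_ldg32" removed from the PassContext
# config -- per the report, this compiles without segfaulting.
with tvm.transform.PassContext(opt_level=0, disabled_pass=["LoopPartition"]):
    tvm.compile(ir_mod, target=target,
                relax_pipeline=relax_pipeline, tir_pipeline=tir_pipeline)
```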
Expected Behavior
On a CPU/LLVM target:
- Setting `tir.ptx_ldg32=1` should either be ignored (a no-op) or rejected with a clear error message (the rejection option is sketched below), and
- PTX-specific passes such as `InjectPTXLDG32` should not run for non-CUDA targets, and
- TVM should never segfault; failures should be surfaced as Python exceptions with diagnostics.
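To make the "rejected with a clear error message" option concrete, here is a minimal sketch; `validate_pass_config` is a hypothetical helper for illustration, not an existing TVM API:

```python
import tvm

def validate_pass_config(config: dict, target: tvm.target.Target) -> None:
    # Hypothetical validation, for illustration only -- not an existing TVM API.
    if config.get("tir.ptx_ldg32") and target.kind.name != "cuda":
        raise ValueError(
            f"tir.ptx_ldg32 requires a CUDA/PTX target, got '{target.kind.name}'"
        )

cpu = tvm.target.Target("llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu")
validate_pass_config({"tir.ptx_ldg32": 1}, cpu)  # raises ValueError, not a segfault
```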
Triage
Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).
- needs-triage
- bug