31 changes: 26 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
@@ -64,6 +64,9 @@ $else:
#include "broadcasting_utils.h"
#include "indexing_utils.h"

$if MASK_PADDING:
  #define MASK_PADDING

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
@@ -140,11 +143,29 @@ void main() {
    other_texel = other_texel.xxxx;
  }

  write_texel_lpos(
      t_out,
      lpos,
      VEC4_OUT_T(op(in_texel, other_texel, alpha)),
      out_axis_map);
  VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, other_texel, alpha));

#ifdef MASK_PADDING
  // Handle padding elements in the last texel to prevent NaN propagation.
  // When the packed dimension size is not a multiple of 4, the last texel
  // along that dimension contains padding elements. For division operators,
  // those padding lanes compute 0 / 0, producing NaN values that can
  // propagate through subsequent operations such as reductions.
  const int nspill = mod4(out_sizes[packed_dim]);

  if (nspill > 0) {
    const int texels_per_batch = divup4(out_sizes[packed_dim]);
    const bool is_last_texel = (lpos[packed_dim] % texels_per_batch) == (texels_per_batch - 1);

    if (is_last_texel) {
      // Explicitly set padding elements to 0 to avoid NaN
      [[unroll]] for (int i = nspill; i < 4; i++) {
        out_texel[i] = 0;
      }
    }
  }
#endif

  write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
}

#endif
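
To make the padding-lane arithmetic in the #ifdef MASK_PADDING block concrete, here is a small standalone sketch in plain Python. It is not part of the shader; it assumes the GLSL helpers mod4 and divup4 behave like n % 4 and ceil(n / 4), consistent with how they are used above.

```python
# Standalone sketch of the padding-lane math used in the shader above.
# Assumption: mod4(n) == n % 4 and divup4(n) == ceil(n / 4).
def mod4(n: int) -> int:
    return n % 4

def divup4(n: int) -> int:
    return (n + 3) // 4

channels = 3                         # packed-dim size, not a multiple of 4
nspill = mod4(channels)              # 3 valid lanes in the last texel
texels_per_batch = divup4(channels)  # 1 texel along the packed dimension

# A texel is the "last" one when (pos % texels_per_batch) == texels_per_batch - 1;
# its lanes nspill..3 are padding and are overwritten with 0 by the fix.
padding_lanes = list(range(nspill, 4))
print(texels_per_batch, padding_lanes)  # -> 1 [3]
```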
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
@@ -10,6 +10,7 @@ binary_op:
    NDIM: 3
    DTYPE: float
    PACKING: C_packed
    MASK_PADDING: 0
  generate_variant_forall:
    STORAGE:
      - VALUE: texture3d
@@ -26,10 +27,12 @@ binary_op:
      OPERATOR: X * Y
    - NAME: binary_div
      OPERATOR: X / Y
      MASK_PADDING: 1
    - NAME: binary_pow
      OPERATOR: pow(X, Y)
    - NAME: binary_floor_divide
      OPERATOR: floor(X / Y)
      MASK_PADDING: 1
    - NAME: binary_minimum
      OPERATOR: min(X, Y)
    - NAME: binary_eq_int32
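
For context on the YAML change: MASK_PADDING is a shader codegen parameter that defaults to 0 and is set to 1 only for the division variants, where 0 / 0 in padding lanes is possible. The sketch below illustrates the intended effect, assuming the template's $if MASK_PADDING: directive simply tests this value; the helper is hypothetical, not the actual generator.

```python
# Hypothetical illustration of how the YAML flag gates the shader define;
# the real logic lives in the Vulkan shader codegen, not in this helper.
def emit_mask_padding_define(variant_params: dict) -> str:
    # Mirrors the template's `$if MASK_PADDING:` guarding `#define MASK_PADDING`.
    return "#define MASK_PADDING\n" if variant_params.get("MASK_PADDING") else ""

assert emit_mask_padding_define({"MASK_PADDING": 1}) == "#define MASK_PADDING\n"
assert emit_mask_padding_define({"MASK_PADDING": 0}) == ""
```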
46 changes: 46 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -1767,6 +1767,52 @@ def forward(self, x):
(torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),),
)

    def test_vulkan_backend_div_with_padding_nan_propagation(self):
        """
        Test division operations with non-multiple-of-4 channels followed by convolution.

        This test verifies the fix for NaN propagation in padding texels during division.
        When the packed dimension (channels=3) is not a multiple of 4, texture-backed
        tensors have padding elements in the last texel. Without proper masking, division
        operations produce NaN values (0/0) in padding regions that propagate through
        subsequent operations like convolution, corrupting results.

        This simulates a common real-world pattern: per-channel image normalization
        (subtract mean, divide by std) followed by convolution.
        """

        class NormalizationConvModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Per-channel mean and std for normalization (shape: [1, 3, 1, 1])
                self.mean = torch.tensor([[[[0.485]], [[0.456]], [[0.406]]]])
                self.std = torch.tensor([[[[0.229]], [[0.224]], [[0.215]]]])

                # Conv2d layer to process normalized image
                self.conv = torch.nn.Conv2d(
                    in_channels=3,  # Non-multiple-of-4 to trigger padding
                    out_channels=16,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                    bias=True,
                )

            def forward(self, x):
                # Simulate image normalization: (x - mean) / std
                # This is where NaN could appear in padding texels without the fix
                x = x - self.mean
                x = x / self.std
                # Convolution operation that would be corrupted by NaN propagation
                x = self.conv(x)
                return x

        module = NormalizationConvModule()
        # Use a typical image tensor size [batch=1, channels=3, height=256, width=256]
        sample_inputs = (torch.randn(size=(1, 3, 256, 256), dtype=torch.float32),)

        self.lower_module_and_test_output(module, sample_inputs)

    def test_vulkan_backend_grid_priors(self):
        class GridPriorsModule(torch.nn.Module):
            def __init__(self):
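
The failure mode the new test guards against can be reproduced outside the Vulkan backend with a deliberately injected 0 / 0 standing in for a padding lane (padding lanes do not exist in eager PyTorch). This illustrative sketch, not part of the PR, shows how a single NaN poisons every convolution output whose receptive field touches it.

```python
import torch

# Stand-in for a padding lane: 0 / 0 yields NaN under IEEE float division.
x = torch.zeros(1, 1, 8)
x[0, 0, -1] = torch.tensor(0.0) / torch.tensor(0.0)  # NaN

conv = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
y = conv(x)

# NaN * w is NaN for any weight w (even 0), so every output position whose
# window overlaps the NaN input element also becomes NaN.
print(torch.isnan(y).sum().item())  # -> 2
```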