torch-xla动态shape——通过torch.nonzero分析mhlo实现
- 电脑硬件
- 2025-09-07 22:15:01

pytorch api:
torch.wheretorch.nonzerotorch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
特别注意torch.nonzero的as_tuple参数:
mhlo算法理解:
python脚本脚本来源xla/issues/4432
import os os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" os.environ["PJRT_DEVICE"] = "XPU" # os.environ["PJRT_DEVICE"] = "GPU" os.environ["XLA_EXPERIMENTAL"]="nonzero" import torch import torch_xla.core.xla_model as xm a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=xm.xla_device()) a2 = torch.nonzero(a1) """ IR { %0 = s64[1,6]{1,0} xla::device_data(), xla_shape=s64[1,6]{1,0}, device=TX8:0 %1 = (s32[<=6,2]{1,0}, s32[]) aten::nonzero(%0), num_outputs=2, xla_shape=(s32[<=6,2]{1,0}, s32[]), ROOT=0 } """ print(torch_xla._XLAC._get_xla_tensors_text([a2])) print(f'{a2.shape=}') # a2.shape=torch.Size([<=6, 2]) print('a2=', a2)运行结果:
mhlo代码 module @SyncTensorsGraph.40 { // %arg0: [[1, 0, 0, 5, 0, 6]] func.func @main(%arg0: tensor<1x6xi64>) -> tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>> { %0 = mhlo.constant dense<0> : tensor<i64> %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i64>) -> tensor<1x6xi64> %2 = mhlo pare NE, %arg0, %1 : (tensor<1x6xi64>, tensor<1x6xi64>) -> tensor<1x6xi1> %3 = mhlo.convert(%2) : (tensor<1x6xi1>) -> tensor<1x6xi32> // %4: [1, 0, 0, 1, 0, 1] %4 = mhlo.reshape %3 : (tensor<1x6xi32>) -> tensor<6xi32> %5 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x6xi32> // %6: [0, 0, 0, 0, 0, 0], row indices %6 = mhlo.reshape %5 : (tensor<1x6xi32>) -> tensor<6xi32> %7 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x6xi32> // %8: [0, 1, 2, 3, 4, 5], col indices %8 = mhlo.reshape %7 : (tensor<1x6xi32>) -> tensor<6xi32> """ 对所有的operand(%4, %6, %8)分别排序得到3个排序后的tensor,且排序条件一致如下 这里3个operand,对应block中有6个element,对应关系为:一个operand对应2个arg %arg1和%arg2为%4中的元素,所以这里操作的结果为:将%6和%8按照%4中元素的大小关系排序 即:将第0,3,5位置的元素排到最前。排序后的结果为: %4'=%9#0 = [1, 1, 1, 0, 0, 0] %6'=%9#1 = [0, 0, 0, 0, 0, 0] 即%10 %8'=%9#2 = [0, 3, 5, 1, 2, 4] 即%11 """ %9:3 = "mhlo.sort"(%4, %6, %8) ({ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>): // %20总是true %20 = mhlo.constant dense<true> : tensor<i1> // 排序条件: %arg1 > %arg2 %21 = mhlo pare GT, %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1> %22 = "mhlo.select"(%20, %21, %20) : (tensor<i1>, tensor<i1>, tensor<i1>) -> tensor<i1> mhlo.return %22 : tensor<i1> }) {dimension = 0 : i64, is_stable = true} : (tensor<6xi32>, tensor<6xi32>, tensor<6xi32>) -> (tensor<6xi32>, tensor<6xi32>, tensor<6xi32>) %10 = mhlo.reshape %9#1 : (tensor<6xi32>) -> tensor<6x1xi32> %11 = mhlo.reshape %9#2 : (tensor<6xi32>) -> tensor<6x1xi32> // %10和%22 concat后,即为非0元素的indices %12 = "mhlo.concatenate"(%10, %11) {dimension = 1 : i64} : (tensor<6x1xi32>, tensor<6x1xi32>) -> tensor<6x2xi32> // 统计非0 index的个数: reduce sum %13 = mhlo.constant dense<0> : tensor<i32> %14 = "mhlo.broadcast_in_dim"(%13) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<6xi32> %15 = mhlo pare GT, %4, %14 : (tensor<6xi32>, tensor<6xi32>) -> tensor<6xi1> %16 = mhlo.convert(%15) : (tensor<6xi1>) -> tensor<6xi32> %17 = mhlo.reduce(%16 init: %13) across dimensions = [0] : (tensor<6xi32>, tensor<i32>) -> tensor<i32> reducer(%arg1: tensor<i32>, %arg2: tensor<i32>) { %20 = mhlo.add %arg1, %arg2 : tensor<i32> mhlo.return %20 : tensor<i32> } // 0维度为动态维度,维度大小,即为统计的个数 %18 = "mhlo.set_dimension_size"(%12, %17) {dimension = 0 : i64} : (tensor<6x2xi32>, tensor<i32>) -> tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>> %19 = "mhlo.tuple"(%18) {xla_shape = "(s32[<=6,2]{1,0})"} : (tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>) -> tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>> return %19 : tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>> } }torch-xla动态shape——通过torch.nonzero分析mhlo实现由讯客互联电脑硬件栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“torch-xla动态shape——通过torch.nonzero分析mhlo实现”