
riscv64: Ensure that we use the same vector length when lowering `bitselect+bitcast+{i,f}cmp` (#8133)

We have a special lowering that allows us to fuse a `bitselect` with a comparison instruction. This saves us a few instructions due to the mismatch that exists between native RISC-V masks and WASM masks.

Native RISC-V masks have a single bit per lane, whereas WASM masks have all bits in a lane set to 1.

The lowering for `bitselect+bitcast+{i,f}cmp` avoids the need to generate the WASM mask, by directly using the comparison mask with `vmerge`.
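
To illustrate why skipping the mask expansion is safe, here is a minimal sketch in plain Rust (not Cranelift code). All function names are hypothetical, and the model assumes 16 one-byte lanes with lane `i` controlled by mask bit `i`. It checks that expanding the bit-per-lane mask and doing a bitwise select gives the same result as merging lane by lane directly from the mask bits.

```rust
/// Expand a RISC-V style mask (one bit per lane, lane i in bit i) into a
/// WASM-style mask for 16 one-byte lanes (every bit of an active lane set).
fn expand_mask_i8x16(bits: u16) -> [u8; 16] {
    let mut out = [0u8; 16];
    for lane in 0..16 {
        if (bits >> lane) & 1 == 1 {
            out[lane] = 0xFF;
        }
    }
    out
}

/// Bitwise select on a full-width mask: (mask & x) | (!mask & y).
fn bitselect(mask: [u8; 16], x: [u8; 16], y: [u8; 16]) -> [u8; 16] {
    let mut out = [0u8; 16];
    for i in 0..16 {
        out[i] = (mask[i] & x[i]) | (!mask[i] & y[i]);
    }
    out
}

/// Lane-wise merge driven directly by the one-bit-per-lane mask, in the
/// spirit of `vmerge.vvm`: take `x` where the mask bit is set, else `y`.
fn merge_by_bits(bits: u16, x: [u8; 16], y: [u8; 16]) -> [u8; 16] {
    let mut out = [0u8; 16];
    for lane in 0..16 {
        out[lane] = if (bits >> lane) & 1 == 1 { x[lane] } else { y[lane] };
    }
    out
}

fn main() {
    let x = [0xAAu8; 16];
    let y = [0x55u8; 16];
    let bits: u16 = 0b1010_0000_1111_0001;
    // Two-step path: expand the mask, then select bitwise.
    let slow = bitselect(expand_mask_i8x16(bits), x, y);
    // Fused path: merge directly from the mask bits.
    let fast = merge_by_bits(bits, x, y);
    assert_eq!(slow, fast);
    println!("{:02x?}", fast);
}
```

The equivalence only holds when the mask was produced by a comparison, so that each lane of the expanded mask is all ones or all zeros, and when both steps agree on the lane layout.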

The bug that this fixes is that when a `bitcast` sits in the middle, the comparison and the merge may have different types with different lane counts. When that happens, the `vmerge` only looks at the first n bits of the mask, where n is the number of lanes currently configured.

This commit ensures that they are always equal by using the same type for both vmerge and the comparison instruction.
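
The mismatch can be reproduced with a simplified model in plain Rust (not Cranelift code). The helpers `cmp_eq_mask` and `vmerge` are hypothetical, the register is assumed to be 128 bits wide, and unused mask bits are assumed to read as zero, which real hardware does not guarantee. The inputs mirror the new run tests: a comparison over two 64-bit lanes feeding a select over sixteen 8-bit lanes.

```rust
/// Simplified model of a vector equality compare that produces one mask bit
/// per lane, where each lane is `sew_bytes` wide.
fn cmp_eq_mask(a: [u8; 16], b: [u8; 16], sew_bytes: usize) -> u16 {
    let mut bits = 0u16;
    for lane in 0..16 / sew_bytes {
        let r = lane * sew_bytes..(lane + 1) * sew_bytes;
        if a[r.clone()] == b[r] {
            bits |= 1 << lane;
        }
    }
    bits
}

/// Simplified model of `vmerge.vvm`: with lanes of `sew_bytes`, there are
/// `16 / sew_bytes` lanes and lane i is selected by mask bit i. Only the
/// first `vl` mask bits are ever read.
fn vmerge(mask_bits: u16, x: [u8; 16], y: [u8; 16], sew_bytes: usize) -> [u8; 16] {
    let vl = 16 / sew_bytes;
    let mut out = [0u8; 16];
    for lane in 0..vl {
        let src = if (mask_bits >> lane) & 1 == 1 { &x } else { &y };
        let range = lane * sew_bytes..(lane + 1) * sew_bytes;
        out[range.clone()].copy_from_slice(&src[range]);
    }
    out
}

fn main() {
    let v0 = [0x80u8; 16];
    let zero = [0u8; 16];
    // Compare as two 64-bit lanes: v0 == v0 holds in both lanes.
    let mask = cmp_eq_mask(v0, v0, 8);
    assert_eq!(mask, 0b11);
    // Mismatch: the merge runs with the bitselect type (16 one-byte lanes),
    // so it reads 16 mask bits although only 2 are meaningful.
    let buggy = vmerge(mask, v0, zero, 1);
    // Fix: the merge runs with the comparison type (two 8-byte lanes).
    let fixed = vmerge(mask, v0, zero, 8);
    assert_ne!(buggy, v0);
    assert_eq!(fixed, v0);
    println!("buggy = {:02x?}", buggy);
    println!("fixed = {:02x?}", fixed);
}
```

In this model the mismatched merge keeps only the first two bytes of `v0`, while the merge that reuses the comparison's lane width returns all of `v0`, which is what the run tests expect.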

I also manually checked all other uses of `gen_{f,i}cmp_mask` and they are all using the same type in the subsequent instructions.

With this fix we no longer really care about the type of the `bitselect`, as long as it has the same bit length as the type of the `{i,f}cmp`, which I think is enforced by the verifier (i.e., we would have the same bug if `bitselect.i8x16+icmp.i8x8` was legal).
Afonso Bordado 8 months ago
committed by GitHub
parent
commit
34f504cd98
  1. cranelift/codegen/src/isa/riscv64/lower.isle (24 lines changed)
  2. cranelift/filetests/filetests/isa/riscv64/simd-bitselect.clif (112 lines changed)
  3. cranelift/filetests/filetests/runtests/simd-bitselect.clif (34 lines changed)

24
cranelift/codegen/src/isa/riscv64/lower.isle

@ -1838,21 +1838,29 @@
 ;;
 ;; This allows us to skip the mask expansion step and use the more efficient
 ;; vmerge.vvm instruction.
+;;
+;; We should be careful to ensure that the mask and the vmerge have the
+;; same type. So that we don't generate a mask with length 16 (i.e. for i8x16), and then
+;; only copy the first few lanes of the result to the destination register because
+;; the bitselect has a different length (i.e. i64x2).
+;;
+;; See: https://github.com/bytecodealliance/wasmtime/issues/8131
-(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (bitselect (icmp cc a @ (value_type (ty_vec_fits_in_register cmp_ty)) b) x y)))
+(rule 2 (lower (has_type (ty_vec_fits_in_register _ty) (bitselect (icmp cc a @ (value_type (ty_vec_fits_in_register cmp_ty)) b) x y)))
 (let ((mask VReg (gen_icmp_mask cmp_ty cc a b)))
-(rv_vmerge_vvm y x mask ty)))
+(rv_vmerge_vvm y x mask cmp_ty)))
-(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (bitselect (fcmp cc a @ (value_type (ty_vec_fits_in_register cmp_ty)) b) x y)))
+(rule 2 (lower (has_type (ty_vec_fits_in_register _ty) (bitselect (fcmp cc a @ (value_type (ty_vec_fits_in_register cmp_ty)) b) x y)))
 (let ((mask VReg (gen_fcmp_mask cmp_ty cc a b)))
-(rv_vmerge_vvm y x mask ty)))
+(rv_vmerge_vvm y x mask cmp_ty)))
-(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (bitselect (bitcast _ (fcmp cc a @ (value_type (ty_vec_fits_in_register cmp_ty)) b)) x y)))
+(rule 2 (lower (has_type (ty_vec_fits_in_register _ty) (bitselect (bitcast _ (fcmp cc a @ (value_type (ty_vec_fits_in_register cmp_ty)) b)) x y)))
 (let ((mask VReg (gen_fcmp_mask cmp_ty cc a b)))
-(rv_vmerge_vvm y x mask ty)))
+(rv_vmerge_vvm y x mask cmp_ty)))
-(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (bitselect (bitcast _ (icmp cc a @ (value_type (ty_vec_fits_in_register cmp_ty)) b)) x y)))
+(rule 2 (lower (has_type (ty_vec_fits_in_register _ty) (bitselect (bitcast _ (icmp cc a @ (value_type (ty_vec_fits_in_register cmp_ty)) b)) x y)))
 (let ((mask VReg (gen_icmp_mask cmp_ty cc a b)))
-(rv_vmerge_vvm y x mask ty)))
+(rv_vmerge_vvm y x mask cmp_ty)))
 ;;;;; Rules for `isplit`;;;;;;;;;

112
cranelift/filetests/filetests/isa/riscv64/simd-bitselect.clif

@ -410,3 +410,115 @@ block0(v0: i64x2, v1: i64x2, v2: f64x2, v3: f64x2):
; addi sp, sp, 0x10
; ret
function %bitselect_i8x16_fcmp_f64x2(i8x16) -> i8x16 fast {
const0 = 0x00000000000000000000000000000000
block0(v0: i8x16):
v1 = bitcast.f64x2 little v0
v2 = fcmp eq v1, v1
v3 = bitcast.i8x16 little v2
v4 = vconst.i8x16 const0
v5 = bitselect.i8x16 v3, v0, v4
return v5
}
; VCode:
; addi sp,sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v9,16(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v14,[const(0)] #avl=16, #vtype=(e8, m1, ta, ma)
; vmfeq.vv v0,v9,v9 #avl=2, #vtype=(e64, m1, ta, ma)
; vmerge.vvm v15,v14,v9,v0.t #avl=2, #vtype=(e64, m1, ta, ma)
; vse8.v v15,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; addi sp,sp,16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; mv s0, sp
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, s0, 0x10
; .byte 0x87, 0x84, 0x0f, 0x02
; auipc t6, 0
; addi t6, t6, 0x34
; .byte 0x07, 0x87, 0x0f, 0x02
; .byte 0x57, 0x70, 0x81, 0xcd
; .byte 0x57, 0x90, 0x94, 0x62
; .byte 0xd7, 0x87, 0xe4, 0x5c
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0xa7, 0x07, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
function %bitselect_i8x16_icmp_i64x2(i8x16) -> i8x16 fast {
const0 = 0x00000000000000000000000000000000
block0(v0: i8x16):
v1 = bitcast.i64x2 little v0
v2 = icmp eq v1, v1
v3 = bitcast.i8x16 little v2
v4 = vconst.i8x16 const0
v5 = bitselect.i8x16 v3, v0, v4
return v5
}
; VCode:
; addi sp,sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v9,16(fp) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v14,[const(0)] #avl=16, #vtype=(e8, m1, ta, ma)
; vmseq.vv v0,v9,v9 #avl=2, #vtype=(e64, m1, ta, ma)
; vmerge.vvm v15,v14,v9,v0.t #avl=2, #vtype=(e64, m1, ta, ma)
; vse8.v v15,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; addi sp,sp,16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; mv s0, sp
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, s0, 0x10
; .byte 0x87, 0x84, 0x0f, 0x02
; auipc t6, 0
; addi t6, t6, 0x34
; .byte 0x07, 0x87, 0x0f, 0x02
; .byte 0x57, 0x70, 0x81, 0xcd
; .byte 0x57, 0x80, 0x94, 0x62
; .byte 0xd7, 0x87, 0xe4, 0x5c
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0xa7, 0x07, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00
; .byte 0x00, 0x00, 0x00, 0x00

34
cranelift/filetests/filetests/runtests/simd-bitselect.clif

@ -94,3 +94,37 @@ block0(v0: i64x2, v1: i64x2, v2: i64x2):
; run: %bitwise_bitselect_i64x2(0x11111111111111111111111111111111, 0x11111111111111111111111111111111, 0x00000000000000000000000000000000) == 0x11111111111111111111111111111111
; run: %bitwise_bitselect_i64x2(0x01010011000011110000000011111111, 0x11111111111111111111111111111111, 0x00000000000000000000000000000000) == 0x01010011000011110000000011111111
; run: %bitwise_bitselect_i64x2(0x00000000000000001111111111111111, 0x00000000000000000000000000000000, 0x11111111111111111111111111111111) == 0x11111111111111110000000000000000
;; See issue #8131
;;
;; These tests test the fusion of `bitselect+bitcast+{f,i}cmp` that
;; some backends perform. Importantly the `fcmp` and the `bitselect`
;; have both different type sizes as well as a different number of
;; lanes.
function %bitselect_i8x16_fcmp_f64x2(i8x16) -> i8x16 fast {
const0 = 0x00000000000000000000000000000000
block0(v0: i8x16):
v1 = bitcast.f64x2 little v0
v2 = fcmp eq v1, v1
v3 = bitcast.i8x16 little v2
v4 = vconst.i8x16 const0
v5 = bitselect.i8x16 v3, v0, v4 ; v4 = const0
return v5
}
; run: %bitselect_i8x16_fcmp_f64x2(0x80808080808080808080808080808080) == 0x80808080808080808080808080808080
function %bitselect_i8x16_icmp_i64x2(i8x16) -> i8x16 fast {
const0 = 0x00000000000000000000000000000000
block0(v0: i8x16):
v1 = bitcast.i64x2 little v0
v2 = icmp eq v1, v1
v3 = bitcast.i8x16 little v2
v4 = vconst.i8x16 const0
v5 = bitselect.i8x16 v3, v0, v4 ; v4 = const0
return v5
}
; run: %bitselect_i8x16_icmp_i64x2(0x80808080808080808080808080808080) == 0x80808080808080808080808080808080
