Blob Blame History Raw
From 9ef733cbe224b1cc12e4c8acac09627ccb3a00d8 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n@gmail.com>
Date: Thu, 21 Apr 2022 20:52:30 -0500
Subject: [PATCH] x86: Optimize {str|wcs}rchr-evex

The new code unrolls the main loop slightly without adding too much
overhead and minimizes the comparisons for the search CHAR.

Geometric Mean of all benchmarks New / Old: 0.755
See email for all results.

Full xcheck passes on x86_64 with and without multiarch enabled.
Reviewed-by: H.J. Lu <hjl.tools@gmail.com>

(cherry picked from commit c966099cdc3e0fdf92f63eac09b22fa7e5f5f02d)
---
 sysdeps/x86_64/multiarch/strrchr-evex.S | 471 +++++++++++++++---------
 1 file changed, 290 insertions(+), 181 deletions(-)

diff --git a/sysdeps/x86_64/multiarch/strrchr-evex.S b/sysdeps/x86_64/multiarch/strrchr-evex.S
index f920b5a5..f5b6d755 100644
--- a/sysdeps/x86_64/multiarch/strrchr-evex.S
+++ b/sysdeps/x86_64/multiarch/strrchr-evex.S
@@ -24,242 +24,351 @@
 #  define STRRCHR	__strrchr_evex
 # endif
 
-# define VMOVU		vmovdqu64
-# define VMOVA		vmovdqa64
+# define VMOVU	vmovdqu64
+# define VMOVA	vmovdqa64
 
 # ifdef USE_AS_WCSRCHR
+#  define SHIFT_REG	esi	/* shrx count in L(cross_page_boundary).  */
+
+#  define kunpck	kunpckbw	/* Joins two 8-bit per-VEC masks.  */
+#  define kmov_2x	kmovd	/* Moves a joined 2x VEC mask to a GPR.  */
+#  define maskz_2x	ecx	/* GPR for the joined zero-CHAR mask.  */
+#  define maskm_2x	eax	/* GPR for the joined match mask.  */
+#  define CHAR_SIZE	4
+#  define VPMIN	vpminud
+#  define VPTESTN	vptestnmd
 #  define VPBROADCAST	vpbroadcastd
-#  define VPCMP		vpcmpd
-#  define SHIFT_REG	r8d
+#  define VPCMP	vpcmpd
 # else
+#  define SHIFT_REG	edi	/* shrx count in L(cross_page_boundary).  */
+
+#  define kunpck	kunpckdq	/* Joins two 32-bit per-VEC masks.  */
+#  define kmov_2x	kmovq	/* Moves a joined 2x VEC mask to a GPR.  */
+#  define maskz_2x	rcx	/* GPR for the joined zero-CHAR mask.  */
+#  define maskm_2x	rax	/* GPR for the joined match mask.  */
+
+#  define CHAR_SIZE	1
+#  define VPMIN	vpminub
+#  define VPTESTN	vptestnmb
 #  define VPBROADCAST	vpbroadcastb
-#  define VPCMP		vpcmpb
-#  define SHIFT_REG	ecx
+#  define VPCMP	vpcmpb
 # endif
 
 # define XMMZERO	xmm16
 # define YMMZERO	ymm16
 # define YMMMATCH	ymm17
-# define YMM1		ymm18
+# define YMMSAVE	ymm18
+
+# define YMM1	ymm19
+# define YMM2	ymm20
+# define YMM3	ymm21
+# define YMM4	ymm22
+# define YMM5	ymm23
+# define YMM6	ymm24
+# define YMM7	ymm25
+# define YMM8	ymm26
 
-# define VEC_SIZE	32
 
-	.section .text.evex,"ax",@progbits
-ENTRY (STRRCHR)
-	movl	%edi, %ecx
+# define VEC_SIZE	32
+# define PAGE_SIZE	4096
+	.section .text.evex, "ax", @progbits
+ENTRY(STRRCHR)
+	movl	%edi, %eax
 	/* Broadcast CHAR to YMMMATCH.  */
 	VPBROADCAST %esi, %YMMMATCH
 
-	vpxorq	%XMMZERO, %XMMZERO, %XMMZERO
-
-	/* Check if we may cross page boundary with one vector load.  */
-	andl	$(2 * VEC_SIZE - 1), %ecx
-	cmpl	$VEC_SIZE, %ecx
-	ja	L(cros_page_boundary)
+	andl	$(PAGE_SIZE - 1), %eax	/* eax = page offset of src.  */
+	cmpl	$(PAGE_SIZE - VEC_SIZE), %eax
+	jg	L(cross_page_boundary)	/* VEC load at src would cross a page.  */
 
+L(page_cross_continue):
 	VMOVU	(%rdi), %YMM1
-
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	/* k0 has a 1 for each zero CHAR in YMM1.  */
+	VPTESTN	%YMM1, %YMM1, %k0
 	kmovd	%k0, %ecx
-	kmovd	%k1, %eax
-
-	addq	$VEC_SIZE, %rdi
-
-	testl	%eax, %eax
-	jnz	L(first_vec)
-
 	testl	%ecx, %ecx
-	jnz	L(return_null)
-
-	andq	$-VEC_SIZE, %rdi
-	xorl	%edx, %edx
-	jmp	L(aligned_loop)
-
-	.p2align 4
-L(first_vec):
-	/* Check if there is a null byte.  */
-	testl	%ecx, %ecx
-	jnz	L(char_and_nul_in_first_vec)
-
-	/* Remember the match and keep searching.  */
-	movl	%eax, %edx
-	movq	%rdi, %rsi
-	andq	$-VEC_SIZE, %rdi
-	jmp	L(aligned_loop)
-
-	.p2align 4
-L(cros_page_boundary):
-	andl	$(VEC_SIZE - 1), %ecx
-	andq	$-VEC_SIZE, %rdi
+	jz	L(aligned_more)
+	/* fallthrough: zero CHAR in first VEC.  */
+
+	/* K1 has a 1 for each search CHAR match in YMM1.  */
+	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	kmovd	%k1, %eax
+	/* Build mask up to and including the first zero CHAR (used to
+	   mask off potential search CHAR matches past the end of the
+	   string).  */
+	blsmskl	%ecx, %ecx
+	andl	%ecx, %eax
+	jz	L(ret0)	/* No in-range match: rax is already 0.  */
+	/* Get last match (the `andl` removed any out of bounds
+	   matches).  */
+	bsrl	%eax, %eax
 # ifdef USE_AS_WCSRCHR
-	/* NB: Divide shift count by 4 since each bit in K1 represent 4
-	   bytes.  */
-	movl	%ecx, %SHIFT_REG
-	sarl	$2, %SHIFT_REG
+	leaq	(%rdi, %rax, CHAR_SIZE), %rax
+# else
+	addq	%rdi, %rax
 # endif
+L(ret0):
+	ret
 
-	VMOVA	(%rdi), %YMM1
-
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
+	/* Returns for first vec x1/x2/x3 have hard coded backward
+	   search path for earlier matches.  */
+	.p2align 4,, 6
+L(first_vec_x1):
+	VPCMP	$0, %YMMMATCH, %YMM2, %k1
+	kmovd	%k1, %eax
+	blsmskl	%ecx, %ecx	/* ecx = mask up to first zero CHAR.  */
+	/* eax non-zero if search CHAR in range.  */
+	andl	%ecx, %eax
+	jnz	L(first_vec_x1_return)
+
+	/* fallthrough: no in-range match in YMM2, so check YMM1.  rsi is
+	   the original src pointer saved in L(aligned_more).  */
+	.p2align 4,, 4
+L(first_vec_x0_test):
 	VPCMP	$0, %YMMMATCH, %YMM1, %k1
-	kmovd	%k0, %edx
 	kmovd	%k1, %eax
-
-	shrxl	%SHIFT_REG, %edx, %edx
-	shrxl	%SHIFT_REG, %eax, %eax
-	addq	$VEC_SIZE, %rdi
-
-	/* Check if there is a CHAR.  */
 	testl	%eax, %eax
-	jnz	L(found_char)
-
-	testl	%edx, %edx
-	jnz	L(return_null)
-
-	jmp	L(aligned_loop)
-
-	.p2align 4
-L(found_char):
-	testl	%edx, %edx
-	jnz	L(char_and_nul)
-
-	/* Remember the match and keep searching.  */
-	movl	%eax, %edx
-	leaq	(%rdi, %rcx), %rsi
+	jz	L(ret1)	/* No match at all: rax is already 0.  */
+	bsrl	%eax, %eax
+# ifdef USE_AS_WCSRCHR
+	leaq	(%rsi, %rax, CHAR_SIZE), %rax
+# else
+	addq	%rsi, %rax
+# endif
+L(ret1):
+	ret
 
-	.p2align 4
-L(aligned_loop):
-	VMOVA	(%rdi), %YMM1
-	addq	$VEC_SIZE, %rdi
+	.p2align 4,, 10
+L(first_vec_x1_or_x2):
+	VPCMP	$0, %YMM3, %YMMMATCH, %k3
+	VPCMP	$0, %YMM2, %YMMMATCH, %k2
+	/* K2 and K3 have 1 for any search CHAR match. Test if any
+	   matches between either of them. Otherwise check YMM1.  */
+	kortestd %k2, %k3
+	jz	L(first_vec_x0_test)
+
+	/* Guaranteed that YMM2 and YMM3 are within range so merge the
+	   two bitmasks then get last result.  */
+	kunpck	%k2, %k3, %k3	/* k3 = k3:k2 (YMM2 mask in the low half).  */
+	kmovq	%k3, %rax
+	bsrq	%rax, %rax
+	leaq	(VEC_SIZE)(%r8, %rax, CHAR_SIZE), %rax	/* r8 = VEC-aligned src.  */
+	ret
 
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
-	kmovd	%k0, %ecx
+	.p2align 4,, 6
+L(first_vec_x3):
+	VPCMP	$0, %YMMMATCH, %YMM4, %k1
 	kmovd	%k1, %eax
-	orl	%eax, %ecx
-	jnz	L(char_nor_null)
+	blsmskl	%ecx, %ecx	/* ecx = mask up to first zero CHAR in YMM4.  */
+	/* If no search CHAR match in range check YMM1/YMM2/YMM3.  */
+	andl	%ecx, %eax
+	jz	L(first_vec_x1_or_x2)
+	bsrl	%eax, %eax
+	leaq	(VEC_SIZE * 3)(%rdi, %rax, CHAR_SIZE), %rax	/* YMM4's address.  */
+	ret
 
-	VMOVA	(%rdi), %YMM1
-	add	$VEC_SIZE, %rdi
+	.p2align 4,, 6
+L(first_vec_x0_x1_test):
+	VPCMP	$0, %YMMMATCH, %YMM2, %k1
+	kmovd	%k1, %eax
+	/* YMM2 is fully in range here: check it first, then YMM1.  */
+	testl	%eax, %eax
+	jz	L(first_vec_x0_test)
+	.p2align 4,, 4
+L(first_vec_x1_return):
+	bsrl	%eax, %eax
+	leaq	(VEC_SIZE)(%rdi, %rax, CHAR_SIZE), %rax	/* YMM2's address.  */
+	ret
 
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
-	kmovd	%k0, %ecx
+	.p2align 4,, 10
+L(first_vec_x2):
+	VPCMP	$0, %YMMMATCH, %YMM3, %k1
 	kmovd	%k1, %eax
-	orl	%eax, %ecx
-	jnz	L(char_nor_null)
+	blsmskl	%ecx, %ecx	/* ecx = mask up to first zero CHAR in YMM3.  */
+	/* Check YMM3 for last match first. If no match try YMM2/YMM1.
+	 */
+	andl	%ecx, %eax
+	jz	L(first_vec_x0_x1_test)
+	bsrl	%eax, %eax
+	leaq	(VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax	/* YMM3's address.  */
+	ret
 
-	VMOVA	(%rdi), %YMM1
-	addq	$VEC_SIZE, %rdi
 
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	.p2align 4
+L(aligned_more):
+	/* Need to keep original pointer in case YMM1 has last match.  */
+	movq	%rdi, %rsi
+	andq	$-VEC_SIZE, %rdi
+	VMOVU	VEC_SIZE(%rdi), %YMM2
+	VPTESTN	%YMM2, %YMM2, %k0
 	kmovd	%k0, %ecx
-	kmovd	%k1, %eax
-	orl	%eax, %ecx
-	jnz	L(char_nor_null)
+	testl	%ecx, %ecx
+	jnz	L(first_vec_x1)
 
-	VMOVA	(%rdi), %YMM1
-	addq	$VEC_SIZE, %rdi
+	VMOVU	(VEC_SIZE * 2)(%rdi), %YMM3
+	VPTESTN	%YMM3, %YMM3, %k0
+	kmovd	%k0, %ecx
+	testl	%ecx, %ecx
+	jnz	L(first_vec_x2)
 
-	/* Each bit in K0 represents a null byte in YMM1.  */
-	VPCMP	$0, %YMMZERO, %YMM1, %k0
-	/* Each bit in K1 represents a CHAR in YMM1.  */
-	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	VMOVU	(VEC_SIZE * 3)(%rdi), %YMM4
+	VPTESTN	%YMM4, %YMM4, %k0
 	kmovd	%k0, %ecx
-	kmovd	%k1, %eax
-	orl	%eax, %ecx
-	jz	L(aligned_loop)
+	movq	%rdi, %r8	/* r8 keeps the VEC-aligned base of YMM2-YMM4.  */
+	testl	%ecx, %ecx
+	jnz	L(first_vec_x3)
 
+	andq	$-(VEC_SIZE * 2), %rdi	/* Loop pointer is 2 * VEC_SIZE aligned.  */
 	.p2align 4
-L(char_nor_null):
-	/* Find a CHAR or a null byte in a loop.  */
+L(first_aligned_loop):
+	/* Preserve YMM1, YMM2, YMM3, and YMM4 until we can guarantee
+	   they don't store a match.  */
+	VMOVA	(VEC_SIZE * 4)(%rdi), %YMM5
+	VMOVA	(VEC_SIZE * 5)(%rdi), %YMM6
+
+	VPCMP	$0, %YMM5, %YMMMATCH, %k2	/* k2 = matches in YMM5.  */
+	vpxord	%YMM6, %YMMMATCH, %YMM7	/* Zero CHAR where YMM6 matches.  */
+
+	VPMIN	%YMM5, %YMM6, %YMM8	/* Zero CHAR iff zero in YMM5 or YMM6.  */
+	VPMIN	%YMM8, %YMM7, %YMM7	/* ... or a search CHAR match in YMM6.  */
+
+	VPTESTN	%YMM7, %YMM7, %k1
+	subq	$(VEC_SIZE * -2), %rdi	/* rdi += 2 * VEC_SIZE.  */
+	kortestd %k1, %k2
+	jz	L(first_aligned_loop)	/* No zero CHAR and no match: keep going.  */
+
+	VPCMP	$0, %YMM6, %YMMMATCH, %k3	/* k3 = matches in YMM6.  */
+	VPTESTN	%YMM8, %YMM8, %k1	/* k1 = zero CHARs of min (YMM5, YMM6).  */
+	ktestd	%k1, %k1
+	jz	L(second_aligned_loop_prep)	/* Match found, string continues.  */
+
+	kortestd %k2, %k3
+	jnz	L(return_first_aligned_loop)
+
+	.p2align 4,, 6
+L(first_vec_x1_or_x2_or_x3):
+	VPCMP	$0, %YMM4, %YMMMATCH, %k4	/* No in-range match past YMM4.  */
+	kmovd	%k4, %eax
 	testl	%eax, %eax
-	jnz	L(match)
-L(return_value):
-	testl	%edx, %edx
-	jz	L(return_null)
-	movl	%edx, %eax
-	movq	%rsi, %rdi
+	jz	L(first_vec_x1_or_x2)
 	bsrl	%eax, %eax
-# ifdef USE_AS_WCSRCHR
-	/* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
-	leaq	-VEC_SIZE(%rdi, %rax, 4), %rax
-# else
-	leaq	-VEC_SIZE(%rdi, %rax), %rax
-# endif
+	leaq	(VEC_SIZE * 3)(%r8, %rax, CHAR_SIZE), %rax	/* r8 = VEC-aligned src.  */
 	ret
 
-	.p2align 4
-L(match):
-	/* Find a CHAR.  Check if there is a null byte.  */
-	kmovd	%k0, %ecx
-	testl	%ecx, %ecx
-	jnz	L(find_nul)
+	.p2align 4,, 8
+L(return_first_aligned_loop):
+	VPTESTN	%YMM5, %YMM5, %k0	/* k0 = zero CHARs in YMM5.  */
+	kunpck	%k0, %k1, %k0	/* k0 = k1:k0 (YMM5 zeros in the low half).  */
+	kmov_2x	%k0, %maskz_2x
+
+	blsmsk	%maskz_2x, %maskz_2x	/* Mask up to the first zero CHAR.  */
+	kunpck	%k2, %k3, %k3	/* k3 = k3:k2 (YMM5 matches in the low half).  */
+	kmov_2x	%k3, %maskm_2x
+	and	%maskz_2x, %maskm_2x	/* Drop matches past the end of string.  */
+	jz	L(first_vec_x1_or_x2_or_x3)
 
-	/* Remember the match and keep searching.  */
-	movl	%eax, %edx
+	bsr	%maskm_2x, %maskm_2x
+	leaq	(VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax	/* rdi already advanced.  */
+	ret
+
+	.p2align 4
+	/* We can throw away the work done for the first 4x checks here
+	   as we have a later match. This is the 'fast' path per se.
+	 */
+L(second_aligned_loop_prep):
+L(second_aligned_loop_set_furthest_match):
 	movq	%rdi, %rsi
-	jmp	L(aligned_loop)
+	kunpck	%k2, %k3, %k4	/* k4 = match masks of the two VECs just checked.  */
 
 	.p2align 4
-L(find_nul):
-	/* Mask out any matching bits after the null byte.  */
-	movl	%ecx, %r8d
-	subl	$1, %r8d
-	xorl	%ecx, %r8d
-	andl	%r8d, %eax
-	testl	%eax, %eax
-	/* If there is no CHAR here, return the remembered one.  */
-	jz	L(return_value)
-	bsrl	%eax, %eax
+L(second_aligned_loop):
+	VMOVU	(VEC_SIZE * 4)(%rdi), %YMM1
+	VMOVU	(VEC_SIZE * 5)(%rdi), %YMM2
+
+	VPCMP	$0, %YMM1, %YMMMATCH, %k2	/* k2 = matches in YMM1.  */
+	vpxord	%YMM2, %YMMMATCH, %YMM3	/* Zero CHAR where YMM2 matches.  */
+
+	VPMIN	%YMM1, %YMM2, %YMM4	/* Zero CHAR iff zero in YMM1 or YMM2.  */
+	VPMIN	%YMM3, %YMM4, %YMM3	/* ... or a search CHAR match in YMM2.  */
+
+	VPTESTN	%YMM3, %YMM3, %k1
+	subq	$(VEC_SIZE * -2), %rdi	/* rdi += 2 * VEC_SIZE.  */
+	kortestd %k1, %k2
+	jz	L(second_aligned_loop)
+
+	VPCMP	$0, %YMM2, %YMMMATCH, %k3	/* k3 = matches in YMM2.  */
+	VPTESTN	%YMM4, %YMM4, %k1	/* k1 = zero CHARs of min (YMM1, YMM2).  */
+	ktestd	%k1, %k1
+	jz	L(second_aligned_loop_set_furthest_match)
+
+	kortestd %k2, %k3
+	/* branch here because there is a significant advantage in terms
+	   of output dependency chance in using edx.  */
+	jnz	L(return_new_match)
+L(return_old_match):
+	kmovq	%k4, %rax	/* k4/rsi: masks and pointer of the saved match.  */
+	bsrq	%rax, %rax
+	leaq	(VEC_SIZE * 2)(%rsi, %rax, CHAR_SIZE), %rax
+	ret
+
+L(return_new_match):
+	VPTESTN	%YMM1, %YMM1, %k0	/* k0 = zero CHARs in YMM1.  */
+	kunpck	%k0, %k1, %k0	/* k0 = k1:k0 (YMM1 zeros in the low half).  */
+	kmov_2x	%k0, %maskz_2x
+
+	blsmsk	%maskz_2x, %maskz_2x	/* Mask up to the first zero CHAR.  */
+	kunpck	%k2, %k3, %k3	/* k3 = k3:k2 (YMM1 matches in the low half).  */
+	kmov_2x	%k3, %maskm_2x
+	and	%maskz_2x, %maskm_2x	/* Drop matches past the end of string.  */
+	jz	L(return_old_match)
+
+	bsr	%maskm_2x, %maskm_2x
+	leaq	(VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
+	ret
+
+L(cross_page_boundary):
+	/* eax contains all the page offset bits of src (rdi). `xor rdi,
+	   rax` sets pointer with all page offset bits cleared so
+	   offset of (PAGE_SIZE - VEC_SIZE) will get last aligned VEC
+	   before page cross (guaranteed to be safe to read). Doing this
+	   as opposed to `movq %rdi, %rax; andq $-VEC_SIZE, %rax` saves
+	   a bit of code size.  */
+	xorq	%rdi, %rax
+	VMOVU	(PAGE_SIZE - VEC_SIZE)(%rax), %YMM1
+	VPTESTN	%YMM1, %YMM1, %k0
+	kmovd	%k0, %ecx
+
+	/* Shift out zero CHAR matches that are before the beginning of
+	   src (rdi).  */
 # ifdef USE_AS_WCSRCHR
-	/* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
-	leaq	-VEC_SIZE(%rdi, %rax, 4), %rax
-# else
-	leaq	-VEC_SIZE(%rdi, %rax), %rax
+	movl	%edi, %esi
+	andl	$(VEC_SIZE - 1), %esi
+	shrl	$2, %esi	/* Byte offset -> wchar_t index (CHAR_SIZE is 4).  */
 # endif
-	ret
+	shrxl	%SHIFT_REG, %ecx, %ecx
 
-	.p2align 4
-L(char_and_nul):
-	/* Find both a CHAR and a null byte.  */
-	addq	%rcx, %rdi
-	movl	%edx, %ecx
-L(char_and_nul_in_first_vec):
-	/* Mask out any matching bits after the null byte.  */
-	movl	%ecx, %r8d
-	subl	$1, %r8d
-	xorl	%ecx, %r8d
-	andl	%r8d, %eax
-	testl	%eax, %eax
-	/* Return null pointer if the null byte comes first.  */
-	jz	L(return_null)
+	testl	%ecx, %ecx
+	jz	L(page_cross_continue)
+
+	/* Found zero CHAR so need to test for search CHAR.  */
+	VPCMP	$0, %YMMMATCH, %YMM1, %k1
+	kmovd	%k1, %eax
+	/* Shift out search CHAR matches that are before the beginning of
+	   src (rdi).  */
+	shrxl	%SHIFT_REG, %eax, %eax
+
+	/* Check if any search CHAR match in range.  */
+	blsmskl	%ecx, %ecx
+	andl	%ecx, %eax
+	jz	L(ret3)	/* No in-range match: rax is already 0.  */
 	bsrl	%eax, %eax
 # ifdef USE_AS_WCSRCHR
-	/* NB: Multiply wchar_t count by 4 to get the number of bytes.  */
-	leaq	-VEC_SIZE(%rdi, %rax, 4), %rax
+	leaq	(%rdi, %rax, CHAR_SIZE), %rax
 # else
-	leaq	-VEC_SIZE(%rdi, %rax), %rax
+	addq	%rdi, %rax
 # endif
+L(ret3):
 	ret
 
-	.p2align 4
-L(return_null):
-	xorl	%eax, %eax
-	ret
-
-END (STRRCHR)
+END(STRRCHR)
 #endif
-- 
GitLab