Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

bpf, x32: Fix bug with ALU64 {LSH, RSH, ARSH} BPF_X shift by 0

The current x32 BPF JIT for shift operations is not correct when the
shift amount in a register is 0. The expected behavior is a no-op, whereas
the current implementation changes bits in the destination register.

The following example demonstrates the bug. The expected result of this
program is 1, but the current JITed code returns 2.

r0 = 1
r1 = 1
r2 = 0
r1 <<= r2
if r1 == 1 goto end
r0 = 2
end:
exit

The bug is caused by an incorrect assumption by the JIT that a shift by
32 clear the register. On x32 however, shifts use the lower 5 bits of
the source, making a shift by 32 equivalent to a shift by 0.

This patch fixes the bug using double-precision shifts, which also
simplifies the code.

Fixes: 03f5781be2c7 ("bpf, x86_32: add eBPF JIT compiler for ia32")
Co-developed-by: Xi Wang <xi.wang@gmail.com>
Signed-off-by: Xi Wang <xi.wang@gmail.com>
Signed-off-by: Luke Nelson <luke.r.nels@gmail.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>

authored by

Luke Nelson and committed by
Daniel Borkmann
68a8357e 0472301a

+28 -203
+28 -203
arch/x86/net/bpf_jit_comp32.c
··· 724 724 { 725 725 u8 *prog = *pprog; 726 726 int cnt = 0; 727 - static int jmp_label1 = -1; 728 - static int jmp_label2 = -1; 729 - static int jmp_label3 = -1; 730 727 u8 dreg_lo = dstk ? IA32_EAX : dst_lo; 731 728 u8 dreg_hi = dstk ? IA32_EDX : dst_hi; 732 729 ··· 742 745 /* mov ecx,src_lo */ 743 746 EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); 744 747 748 + /* shld dreg_hi,dreg_lo,cl */ 749 + EMIT3(0x0F, 0xA5, add_2reg(0xC0, dreg_hi, dreg_lo)); 750 + /* shl dreg_lo,cl */ 751 + EMIT2(0xD3, add_1reg(0xE0, dreg_lo)); 752 + 753 + /* if ecx >= 32, mov dreg_lo into dreg_hi and clear dreg_lo */ 754 + 745 755 /* cmp ecx,32 */ 746 756 EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); 747 - /* Jumps when >= 32 */ 748 - if (is_imm8(jmp_label(jmp_label1, 2))) 749 - EMIT2(IA32_JAE, jmp_label(jmp_label1, 2)); 750 - else 751 - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6)); 757 + /* skip the next two instructions (4 bytes) when < 32 */ 758 + EMIT2(IA32_JB, 4); 752 759 753 - /* < 32 */ 754 - /* shl dreg_hi,cl */ 755 - EMIT2(0xD3, add_1reg(0xE0, dreg_hi)); 756 - /* mov ebx,dreg_lo */ 757 - EMIT2(0x8B, add_2reg(0xC0, dreg_lo, IA32_EBX)); 758 - /* shl dreg_lo,cl */ 759 - EMIT2(0xD3, add_1reg(0xE0, dreg_lo)); 760 - 761 - /* IA32_ECX = -IA32_ECX + 32 */ 762 - /* neg ecx */ 763 - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); 764 - /* add ecx,32 */ 765 - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); 766 - 767 - /* shr ebx,cl */ 768 - EMIT2(0xD3, add_1reg(0xE8, IA32_EBX)); 769 - /* or dreg_hi,ebx */ 770 - EMIT2(0x09, add_2reg(0xC0, dreg_hi, IA32_EBX)); 771 - 772 - /* goto out; */ 773 - if (is_imm8(jmp_label(jmp_label3, 2))) 774 - EMIT2(0xEB, jmp_label(jmp_label3, 2)); 775 - else 776 - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); 777 - 778 - /* >= 32 */ 779 - if (jmp_label1 == -1) 780 - jmp_label1 = cnt; 781 - 782 - /* cmp ecx,64 */ 783 - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64); 784 - /* Jumps when >= 64 */ 785 - if (is_imm8(jmp_label(jmp_label2, 2))) 786 - EMIT2(IA32_JAE, jmp_label(jmp_label2, 2)); 787 - else 788 - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6)); 789 - 790 - /* >= 32 && < 64 */ 791 - /* sub ecx,32 */ 792 - EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32); 793 - /* shl dreg_lo,cl */ 794 - EMIT2(0xD3, add_1reg(0xE0, dreg_lo)); 795 760 /* mov dreg_hi,dreg_lo */ 796 761 EMIT2(0x89, add_2reg(0xC0, dreg_hi, dreg_lo)); 797 - 798 762 /* xor dreg_lo,dreg_lo */ 799 763 EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo)); 800 - 801 - /* goto out; */ 802 - if (is_imm8(jmp_label(jmp_label3, 2))) 803 - EMIT2(0xEB, jmp_label(jmp_label3, 2)); 804 - else 805 - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); 806 - 807 - /* >= 64 */ 808 - if (jmp_label2 == -1) 809 - jmp_label2 = cnt; 810 - /* xor dreg_lo,dreg_lo */ 811 - EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo)); 812 - /* xor dreg_hi,dreg_hi */ 813 - EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); 814 - 815 - if (jmp_label3 == -1) 816 - jmp_label3 = cnt; 817 764 818 765 if (dstk) { 819 766 /* mov dword ptr [ebp+off],dreg_lo */ ··· 777 836 { 778 837 u8 *prog = *pprog; 779 838 int cnt = 0; 780 - static int jmp_label1 = -1; 781 - static int jmp_label2 = -1; 782 - static int jmp_label3 = -1; 783 839 u8 dreg_lo = dstk ? IA32_EAX : dst_lo; 784 840 u8 dreg_hi = dstk ? IA32_EDX : dst_hi; 785 841 ··· 795 857 /* mov ecx,src_lo */ 796 858 EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); 797 859 860 + /* shrd dreg_lo,dreg_hi,cl */ 861 + EMIT3(0x0F, 0xAD, add_2reg(0xC0, dreg_lo, dreg_hi)); 862 + /* sar dreg_hi,cl */ 863 + EMIT2(0xD3, add_1reg(0xF8, dreg_hi)); 864 + 865 + /* if ecx >= 32, mov dreg_hi to dreg_lo and set/clear dreg_hi depending on sign */ 866 + 798 867 /* cmp ecx,32 */ 799 868 EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); 800 - /* Jumps when >= 32 */ 801 - if (is_imm8(jmp_label(jmp_label1, 2))) 802 - EMIT2(IA32_JAE, jmp_label(jmp_label1, 2)); 803 - else 804 - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6)); 869 + /* skip the next two instructions (5 bytes) when < 32 */ 870 + EMIT2(IA32_JB, 5); 805 871 806 - /* < 32 */ 807 - /* lshr dreg_lo,cl */ 808 - EMIT2(0xD3, add_1reg(0xE8, dreg_lo)); 809 - /* mov ebx,dreg_hi */ 810 - EMIT2(0x8B, add_2reg(0xC0, dreg_hi, IA32_EBX)); 811 - /* ashr dreg_hi,cl */ 812 - EMIT2(0xD3, add_1reg(0xF8, dreg_hi)); 813 - 814 - /* IA32_ECX = -IA32_ECX + 32 */ 815 - /* neg ecx */ 816 - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); 817 - /* add ecx,32 */ 818 - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); 819 - 820 - /* shl ebx,cl */ 821 - EMIT2(0xD3, add_1reg(0xE0, IA32_EBX)); 822 - /* or dreg_lo,ebx */ 823 - EMIT2(0x09, add_2reg(0xC0, dreg_lo, IA32_EBX)); 824 - 825 - /* goto out; */ 826 - if (is_imm8(jmp_label(jmp_label3, 2))) 827 - EMIT2(0xEB, jmp_label(jmp_label3, 2)); 828 - else 829 - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); 830 - 831 - /* >= 32 */ 832 - if (jmp_label1 == -1) 833 - jmp_label1 = cnt; 834 - 835 - /* cmp ecx,64 */ 836 - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64); 837 - /* Jumps when >= 64 */ 838 - if (is_imm8(jmp_label(jmp_label2, 2))) 839 - EMIT2(IA32_JAE, jmp_label(jmp_label2, 2)); 840 - else 841 - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6)); 842 - 843 - /* >= 32 && < 64 */ 844 - /* sub ecx,32 */ 845 - EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32); 846 - /* ashr dreg_hi,cl */ 847 - EMIT2(0xD3, add_1reg(0xF8, dreg_hi)); 848 872 /* mov dreg_lo,dreg_hi */ 849 873 EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi)); 850 - 851 - /* ashr dreg_hi,imm8 */ 874 + /* sar dreg_hi,31 */ 852 875 EMIT3(0xC1, add_1reg(0xF8, dreg_hi), 31); 853 - 854 - /* goto out; */ 855 - if (is_imm8(jmp_label(jmp_label3, 2))) 856 - EMIT2(0xEB, jmp_label(jmp_label3, 2)); 857 - else 858 - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); 859 - 860 - /* >= 64 */ 861 - if (jmp_label2 == -1) 862 - jmp_label2 = cnt; 863 - /* ashr dreg_hi,imm8 */ 864 - EMIT3(0xC1, add_1reg(0xF8, dreg_hi), 31); 865 - /* mov dreg_lo,dreg_hi */ 866 - EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi)); 867 - 868 - if (jmp_label3 == -1) 869 - jmp_label3 = cnt; 870 876 871 877 if (dstk) { 872 878 /* mov dword ptr [ebp+off],dreg_lo */ ··· 830 948 { 831 949 u8 *prog = *pprog; 832 950 int cnt = 0; 833 - static int jmp_label1 = -1; 834 - static int jmp_label2 = -1; 835 - static int jmp_label3 = -1; 836 951 u8 dreg_lo = dstk ? IA32_EAX : dst_lo; 837 952 u8 dreg_hi = dstk ? IA32_EDX : dst_hi; 838 953 ··· 848 969 /* mov ecx,src_lo */ 849 970 EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); 850 971 972 + /* shrd dreg_lo,dreg_hi,cl */ 973 + EMIT3(0x0F, 0xAD, add_2reg(0xC0, dreg_lo, dreg_hi)); 974 + /* shr dreg_hi,cl */ 975 + EMIT2(0xD3, add_1reg(0xE8, dreg_hi)); 976 + 977 + /* if ecx >= 32, mov dreg_hi to dreg_lo and clear dreg_hi */ 978 + 851 979 /* cmp ecx,32 */ 852 980 EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); 853 - /* Jumps when >= 32 */ 854 - if (is_imm8(jmp_label(jmp_label1, 2))) 855 - EMIT2(IA32_JAE, jmp_label(jmp_label1, 2)); 856 - else 857 - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6)); 981 + /* skip the next two instructions (4 bytes) when < 32 */ 982 + EMIT2(IA32_JB, 4); 858 983 859 - /* < 32 */ 860 - /* lshr dreg_lo,cl */ 861 - EMIT2(0xD3, add_1reg(0xE8, dreg_lo)); 862 - /* mov ebx,dreg_hi */ 863 - EMIT2(0x8B, add_2reg(0xC0, dreg_hi, IA32_EBX)); 864 - /* shr dreg_hi,cl */ 865 - EMIT2(0xD3, add_1reg(0xE8, dreg_hi)); 866 - 867 - /* IA32_ECX = -IA32_ECX + 32 */ 868 - /* neg ecx */ 869 - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); 870 - /* add ecx,32 */ 871 - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); 872 - 873 - /* shl ebx,cl */ 874 - EMIT2(0xD3, add_1reg(0xE0, IA32_EBX)); 875 - /* or dreg_lo,ebx */ 876 - EMIT2(0x09, add_2reg(0xC0, dreg_lo, IA32_EBX)); 877 - 878 - /* goto out; */ 879 - if (is_imm8(jmp_label(jmp_label3, 2))) 880 - EMIT2(0xEB, jmp_label(jmp_label3, 2)); 881 - else 882 - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); 883 - 884 - /* >= 32 */ 885 - if (jmp_label1 == -1) 886 - jmp_label1 = cnt; 887 - /* cmp ecx,64 */ 888 - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64); 889 - /* Jumps when >= 64 */ 890 - if (is_imm8(jmp_label(jmp_label2, 2))) 891 - EMIT2(IA32_JAE, jmp_label(jmp_label2, 2)); 892 - else 893 - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6)); 894 - 895 - /* >= 32 && < 64 */ 896 - /* sub ecx,32 */ 897 - EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32); 898 - /* shr dreg_hi,cl */ 899 - EMIT2(0xD3, add_1reg(0xE8, dreg_hi)); 900 984 /* mov dreg_lo,dreg_hi */ 901 985 EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi)); 902 986 /* xor dreg_hi,dreg_hi */ 903 987 EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); 904 - 905 - /* goto out; */ 906 - if (is_imm8(jmp_label(jmp_label3, 2))) 907 - EMIT2(0xEB, jmp_label(jmp_label3, 2)); 908 - else 909 - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); 910 - 911 - /* >= 64 */ 912 - if (jmp_label2 == -1) 913 - jmp_label2 = cnt; 914 - /* xor dreg_lo,dreg_lo */ 915 - EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo)); 916 - /* xor dreg_hi,dreg_hi */ 917 - EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); 918 - 919 - if (jmp_label3 == -1) 920 - jmp_label3 = cnt; 921 988 922 989 if (dstk) { 923 990 /* mov dword ptr [ebp+off],dreg_lo */