Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

bpf: specify the old and new poke_type for bpf_arch_text_poke

In the original logic, bpf_arch_text_poke() assumes that the old and new
instructions have the same opcode. However, they can have different opcodes
if we want to replace a "call" insn with a "jmp" insn.

Therefore, add the new function parameter "old_t" along with the "new_t",
which are used to indicate the old and new poke type. Meanwhile, adjust
the implementation of bpf_arch_text_poke() for all the archs.

"BPF_MOD_NOP" is added to make the code more readable. In
bpf_arch_text_poke(), we still check whether the new and old addresses are
NULL to determine if a nop insn should be used, which I think is safer.

Signed-off-by: Menglong Dong <dongml2@chinatelecom.cn>
Link: https://lore.kernel.org/r/20251118123639.688444-6-dongml2@chinatelecom.cn
Signed-off-by: Alexei Starovoitov <ast@kernel.org>

authored by

Menglong Dong and committed by
Alexei Starovoitov
ae4a3160 373f2f44

+71 -46
+7 -7
arch/arm64/net/bpf_jit_comp.c
··· 2934 2934 * The dummy_tramp is used to prevent another CPU from jumping to unknown 2935 2935 * locations during the patching process, making the patching process easier. 2936 2936 */ 2937 - int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type, 2938 - void *old_addr, void *new_addr) 2937 + int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 2938 + enum bpf_text_poke_type new_t, void *old_addr, 2939 + void *new_addr) 2939 2940 { 2940 2941 int ret; 2941 2942 u32 old_insn; ··· 2980 2979 !poking_bpf_entry)) 2981 2980 return -EINVAL; 2982 2981 2983 - if (poke_type == BPF_MOD_CALL) 2984 - branch_type = AARCH64_INSN_BRANCH_LINK; 2985 - else 2986 - branch_type = AARCH64_INSN_BRANCH_NOLINK; 2987 - 2982 + branch_type = old_t == BPF_MOD_CALL ? AARCH64_INSN_BRANCH_LINK : 2983 + AARCH64_INSN_BRANCH_NOLINK; 2988 2984 if (gen_branch_or_nop(branch_type, ip, old_addr, plt, &old_insn) < 0) 2989 2985 return -EFAULT; 2990 2986 2987 + branch_type = new_t == BPF_MOD_CALL ? AARCH64_INSN_BRANCH_LINK : 2988 + AARCH64_INSN_BRANCH_NOLINK; 2991 2989 if (gen_branch_or_nop(branch_type, ip, new_addr, plt, &new_insn) < 0) 2992 2990 return -EFAULT; 2993 2991
+6 -3
arch/loongarch/net/bpf_jit.c
··· 1284 1284 return ret ? ERR_PTR(-EINVAL) : dst; 1285 1285 } 1286 1286 1287 - int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type, 1288 - void *old_addr, void *new_addr) 1287 + int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 1288 + enum bpf_text_poke_type new_t, void *old_addr, 1289 + void *new_addr) 1289 1290 { 1290 1291 int ret; 1291 - bool is_call = (poke_type == BPF_MOD_CALL); 1292 + bool is_call; 1292 1293 u32 old_insns[LOONGARCH_LONG_JUMP_NINSNS] = {[0 ... 4] = INSN_NOP}; 1293 1294 u32 new_insns[LOONGARCH_LONG_JUMP_NINSNS] = {[0 ... 4] = INSN_NOP}; 1294 1295 ··· 1299 1298 if (!is_bpf_text_address((unsigned long)ip)) 1300 1299 return -ENOTSUPP; 1301 1300 1301 + is_call = old_t == BPF_MOD_CALL; 1302 1302 ret = emit_jump_or_nops(old_addr, ip, old_insns, is_call); 1303 1303 if (ret) 1304 1304 return ret; ··· 1307 1305 if (memcmp(ip, old_insns, LOONGARCH_LONG_JUMP_NBYTES)) 1308 1306 return -EFAULT; 1309 1307 1308 + is_call = new_t == BPF_MOD_CALL; 1310 1309 ret = emit_jump_or_nops(new_addr, ip, new_insns, is_call); 1311 1310 if (ret) 1312 1311 return ret;
+6 -4
arch/powerpc/net/bpf_jit_comp.c
··· 1107 1107 * execute isync (or some CSI) so that they don't go back into the 1108 1108 * trampoline again. 1109 1109 */ 1110 - int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type, 1111 - void *old_addr, void *new_addr) 1110 + int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 1111 + enum bpf_text_poke_type new_t, void *old_addr, 1112 + void *new_addr) 1112 1113 { 1113 1114 unsigned long bpf_func, bpf_func_end, size, offset; 1114 1115 ppc_inst_t old_inst, new_inst; ··· 1120 1119 return -EOPNOTSUPP; 1121 1120 1122 1121 bpf_func = (unsigned long)ip; 1123 - branch_flags = poke_type == BPF_MOD_CALL ? BRANCH_SET_LINK : 0; 1124 1122 1125 1123 /* We currently only support poking bpf programs */ 1126 1124 if (!__bpf_address_lookup(bpf_func, &size, &offset, name)) { ··· 1132 1132 * an unconditional branch instruction at im->ip_after_call 1133 1133 */ 1134 1134 if (offset) { 1135 - if (poke_type != BPF_MOD_JUMP) { 1135 + if (old_t == BPF_MOD_CALL || new_t == BPF_MOD_CALL) { 1136 1136 pr_err("%s (0x%lx): calls are not supported in bpf prog body\n", __func__, 1137 1137 bpf_func); 1138 1138 return -EOPNOTSUPP; ··· 1166 1166 } 1167 1167 1168 1168 old_inst = ppc_inst(PPC_RAW_NOP()); 1169 + branch_flags = old_t == BPF_MOD_CALL ? BRANCH_SET_LINK : 0; 1169 1170 if (old_addr) { 1170 1171 if (is_offset_in_branch_range(ip - old_addr)) 1171 1172 create_branch(&old_inst, ip, (unsigned long)old_addr, branch_flags); ··· 1175 1174 branch_flags); 1176 1175 } 1177 1176 new_inst = ppc_inst(PPC_RAW_NOP()); 1177 + branch_flags = new_t == BPF_MOD_CALL ? BRANCH_SET_LINK : 0; 1178 1178 if (new_addr) { 1179 1179 if (is_offset_in_branch_range(ip - new_addr)) 1180 1180 create_branch(&new_inst, ip, (unsigned long)new_addr, branch_flags);
+6 -3
arch/riscv/net/bpf_jit_comp64.c
··· 852 852 return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx); 853 853 } 854 854 855 - int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type, 856 - void *old_addr, void *new_addr) 855 + int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 856 + enum bpf_text_poke_type new_t, void *old_addr, 857 + void *new_addr) 857 858 { 858 859 u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS]; 859 - bool is_call = poke_type == BPF_MOD_CALL; 860 + bool is_call; 860 861 int ret; 861 862 862 863 if (!is_kernel_text((unsigned long)ip) && 863 864 !is_bpf_text_address((unsigned long)ip)) 864 865 return -ENOTSUPP; 865 866 867 + is_call = old_t == BPF_MOD_CALL; 866 868 ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call); 867 869 if (ret) 868 870 return ret; ··· 872 870 if (memcmp(ip, old_insns, RV_FENTRY_NBYTES)) 873 871 return -EFAULT; 874 872 873 + is_call = new_t == BPF_MOD_CALL; 875 874 ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call); 876 875 if (ret) 877 876 return ret;
+4 -3
arch/s390/net/bpf_jit_comp.c
··· 2413 2413 return true; 2414 2414 } 2415 2415 2416 - int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, 2417 - void *old_addr, void *new_addr) 2416 + int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 2417 + enum bpf_text_poke_type new_t, void *old_addr, 2418 + void *new_addr) 2418 2419 { 2419 2420 struct bpf_plt expected_plt, current_plt, new_plt, *plt; 2420 2421 struct { ··· 2432 2431 if (insn.opc != (0xc004 | (old_addr ? 0xf0 : 0))) 2433 2432 return -EINVAL; 2434 2433 2435 - if (t == BPF_MOD_JUMP && 2434 + if ((new_t == BPF_MOD_JUMP || old_t == BPF_MOD_JUMP) && 2436 2435 insn.disp == ((char *)new_addr - (char *)ip) >> 1) { 2437 2436 /* 2438 2437 * The branch already points to the destination,
+21 -16
arch/x86/net/bpf_jit_comp.c
··· 597 597 return emit_patch(pprog, func, ip, 0xE9); 598 598 } 599 599 600 - static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, 600 + static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 601 + enum bpf_text_poke_type new_t, 601 602 void *old_addr, void *new_addr) 602 603 { 603 604 const u8 *nop_insn = x86_nops[5]; ··· 608 607 int ret; 609 608 610 609 memcpy(old_insn, nop_insn, X86_PATCH_SIZE); 611 - if (old_addr) { 610 + if (old_t != BPF_MOD_NOP && old_addr) { 612 611 prog = old_insn; 613 - ret = t == BPF_MOD_CALL ? 612 + ret = old_t == BPF_MOD_CALL ? 614 613 emit_call(&prog, old_addr, ip) : 615 614 emit_jump(&prog, old_addr, ip); 616 615 if (ret) ··· 618 617 } 619 618 620 619 memcpy(new_insn, nop_insn, X86_PATCH_SIZE); 621 - if (new_addr) { 620 + if (new_t != BPF_MOD_NOP && new_addr) { 622 621 prog = new_insn; 623 - ret = t == BPF_MOD_CALL ? 622 + ret = new_t == BPF_MOD_CALL ? 624 623 emit_call(&prog, new_addr, ip) : 625 624 emit_jump(&prog, new_addr, ip); 626 625 if (ret) ··· 641 640 return ret; 642 641 } 643 642 644 - int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, 645 - void *old_addr, void *new_addr) 643 + int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 644 + enum bpf_text_poke_type new_t, void *old_addr, 645 + void *new_addr) 646 646 { 647 647 if (!is_kernel_text((long)ip) && 648 648 !is_bpf_text_address((long)ip)) ··· 657 655 if (is_endbr(ip)) 658 656 ip += ENDBR_INSN_SIZE; 659 657 660 - return __bpf_arch_text_poke(ip, t, old_addr, new_addr); 658 + return __bpf_arch_text_poke(ip, old_t, new_t, old_addr, new_addr); 661 659 } 662 660 663 661 #define EMIT_LFENCE() EMIT3(0x0F, 0xAE, 0xE8) ··· 899 897 target = array->ptrs[poke->tail_call.key]; 900 898 if (target) { 901 899 ret = __bpf_arch_text_poke(poke->tailcall_target, 902 - BPF_MOD_JUMP, NULL, 900 + BPF_MOD_NOP, BPF_MOD_JUMP, 901 + NULL, 903 902 (u8 *)target->bpf_func + 904 903 poke->adj_off); 905 904 BUG_ON(ret < 0); 906 905 ret = 
__bpf_arch_text_poke(poke->tailcall_bypass, 907 - BPF_MOD_JUMP, 906 + BPF_MOD_JUMP, BPF_MOD_NOP, 908 907 (u8 *)poke->tailcall_target + 909 908 X86_PATCH_SIZE, NULL); 910 909 BUG_ON(ret < 0); ··· 3988 3985 struct bpf_prog *new, struct bpf_prog *old) 3989 3986 { 3990 3987 u8 *old_addr, *new_addr, *old_bypass_addr; 3988 + enum bpf_text_poke_type t; 3991 3989 int ret; 3992 3990 3993 3991 old_bypass_addr = old ? NULL : poke->bypass_addr; ··· 4001 3997 * the kallsyms check. 4002 3998 */ 4003 3999 if (new) { 4000 + t = old_addr ? BPF_MOD_JUMP : BPF_MOD_NOP; 4004 4001 ret = __bpf_arch_text_poke(poke->tailcall_target, 4005 - BPF_MOD_JUMP, 4002 + t, BPF_MOD_JUMP, 4006 4003 old_addr, new_addr); 4007 4004 BUG_ON(ret < 0); 4008 4005 if (!old) { 4009 4006 ret = __bpf_arch_text_poke(poke->tailcall_bypass, 4010 - BPF_MOD_JUMP, 4007 + BPF_MOD_JUMP, BPF_MOD_NOP, 4011 4008 poke->bypass_addr, 4012 4009 NULL); 4013 4010 BUG_ON(ret < 0); 4014 4011 } 4015 4012 } else { 4013 + t = old_bypass_addr ? BPF_MOD_JUMP : BPF_MOD_NOP; 4016 4014 ret = __bpf_arch_text_poke(poke->tailcall_bypass, 4017 - BPF_MOD_JUMP, 4018 - old_bypass_addr, 4015 + t, BPF_MOD_JUMP, old_bypass_addr, 4019 4016 poke->bypass_addr); 4020 4017 BUG_ON(ret < 0); 4021 4018 /* let other CPUs finish the execution of program ··· 4025 4020 */ 4026 4021 if (!ret) 4027 4022 synchronize_rcu(); 4023 + t = old_addr ? BPF_MOD_JUMP : BPF_MOD_NOP; 4028 4024 ret = __bpf_arch_text_poke(poke->tailcall_target, 4029 - BPF_MOD_JUMP, 4030 - old_addr, NULL); 4025 + t, BPF_MOD_NOP, old_addr, NULL); 4031 4026 BUG_ON(ret < 0); 4032 4027 } 4033 4028 }
+4 -2
include/linux/bpf.h
··· 3710 3710 #endif /* CONFIG_INET */ 3711 3711 3712 3712 enum bpf_text_poke_type { 3713 + BPF_MOD_NOP, 3713 3714 BPF_MOD_CALL, 3714 3715 BPF_MOD_JUMP, 3715 3716 }; 3716 3717 3717 - int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, 3718 - void *addr1, void *addr2); 3718 + int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 3719 + enum bpf_text_poke_type new_t, void *old_addr, 3720 + void *new_addr); 3719 3721 3720 3722 void bpf_arch_poke_desc_update(struct bpf_jit_poke_descriptor *poke, 3721 3723 struct bpf_prog *new, struct bpf_prog *old);
+3 -2
kernel/bpf/core.c
··· 3150 3150 return -EFAULT; 3151 3151 } 3152 3152 3153 - int __weak bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, 3154 - void *addr1, void *addr2) 3153 + int __weak bpf_arch_text_poke(void *ip, enum bpf_text_poke_type old_t, 3154 + enum bpf_text_poke_type new_t, void *old_addr, 3155 + void *new_addr) 3155 3156 { 3156 3157 return -ENOTSUPP; 3157 3158 }
+14 -6
kernel/bpf/trampoline.c
··· 183 183 if (tr->func.ftrace_managed) 184 184 ret = unregister_ftrace_direct(tr->fops, (long)old_addr, false); 185 185 else 186 - ret = bpf_arch_text_poke(ip, BPF_MOD_CALL, old_addr, NULL); 186 + ret = bpf_arch_text_poke(ip, BPF_MOD_CALL, BPF_MOD_NOP, 187 + old_addr, NULL); 187 188 188 189 return ret; 189 190 } ··· 201 200 else 202 201 ret = modify_ftrace_direct_nolock(tr->fops, (long)new_addr); 203 202 } else { 204 - ret = bpf_arch_text_poke(ip, BPF_MOD_CALL, old_addr, new_addr); 203 + ret = bpf_arch_text_poke(ip, 204 + old_addr ? BPF_MOD_CALL : BPF_MOD_NOP, 205 + new_addr ? BPF_MOD_CALL : BPF_MOD_NOP, 206 + old_addr, new_addr); 205 207 } 206 208 return ret; 207 209 } ··· 229 225 return ret; 230 226 ret = register_ftrace_direct(tr->fops, (long)new_addr); 231 227 } else { 232 - ret = bpf_arch_text_poke(ip, BPF_MOD_CALL, NULL, new_addr); 228 + ret = bpf_arch_text_poke(ip, BPF_MOD_NOP, BPF_MOD_CALL, 229 + NULL, new_addr); 233 230 } 234 231 235 232 return ret; ··· 341 336 * call_rcu_tasks() is not necessary. 
342 337 */ 343 338 if (im->ip_after_call) { 344 - int err = bpf_arch_text_poke(im->ip_after_call, BPF_MOD_JUMP, 345 - NULL, im->ip_epilogue); 339 + int err = bpf_arch_text_poke(im->ip_after_call, BPF_MOD_NOP, 340 + BPF_MOD_JUMP, NULL, 341 + im->ip_epilogue); 346 342 WARN_ON(err); 347 343 if (IS_ENABLED(CONFIG_TASKS_RCU)) 348 344 call_rcu_tasks(&im->rcu, __bpf_tramp_image_put_rcu_tasks); ··· 576 570 if (err) 577 571 return err; 578 572 tr->extension_prog = link->link.prog; 579 - return bpf_arch_text_poke(tr->func.addr, BPF_MOD_JUMP, NULL, 573 + return bpf_arch_text_poke(tr->func.addr, BPF_MOD_NOP, 574 + BPF_MOD_JUMP, NULL, 580 575 link->link.prog->bpf_func); 581 576 } 582 577 if (cnt >= BPF_MAX_TRAMP_LINKS) ··· 625 618 if (kind == BPF_TRAMP_REPLACE) { 626 619 WARN_ON_ONCE(!tr->extension_prog); 627 620 err = bpf_arch_text_poke(tr->func.addr, BPF_MOD_JUMP, 621 + BPF_MOD_NOP, 628 622 tr->extension_prog->bpf_func, NULL); 629 623 tr->extension_prog = NULL; 630 624 guard(mutex)(&tgt_prog->aux->ext_mutex);