diff --git a/arch/arm64/include/asm/extable.h b/arch/arm64/include/asm/extable.h index 56a4f68b262e..840a35ed92ec 100644 --- a/arch/arm64/include/asm/extable.h +++ b/arch/arm64/include/asm/extable.h @@ -22,5 +22,17 @@ struct exception_table_entry #define ARCH_HAS_RELATIVE_EXTABLE +#ifdef CONFIG_BPF_JIT +int arm64_bpf_fixup_exception(const struct exception_table_entry *ex, + struct pt_regs *regs); +#else /* !CONFIG_BPF_JIT */ +static inline +int arm64_bpf_fixup_exception(const struct exception_table_entry *ex, + struct pt_regs *regs) +{ + return 0; +} +#endif /* !CONFIG_BPF_JIT */ + extern int fixup_exception(struct pt_regs *regs); #endif diff --git a/arch/arm64/mm/extable.c b/arch/arm64/mm/extable.c index 81e694af5f8c..eee1732ab6cd 100644 --- a/arch/arm64/mm/extable.c +++ b/arch/arm64/mm/extable.c @@ -11,8 +11,14 @@ int fixup_exception(struct pt_regs *regs) const struct exception_table_entry *fixup; fixup = search_exception_tables(instruction_pointer(regs)); - if (fixup) - regs->pc = (unsigned long)&fixup->fixup + fixup->fixup; + if (!fixup) + return 0; - return fixup != NULL; + if (IS_ENABLED(CONFIG_BPF_JIT) && + regs->pc >= BPF_JIT_REGION_START && + regs->pc < BPF_JIT_REGION_END) + return arm64_bpf_fixup_exception(fixup, regs); + + regs->pc = (unsigned long)&fixup->fixup + fixup->fixup; + return 1; } diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c index 3cb25b43b368..f8912e45be7a 100644 --- a/arch/arm64/net/bpf_jit_comp.c +++ b/arch/arm64/net/bpf_jit_comp.c @@ -7,6 +7,7 @@ #define pr_fmt(fmt) "bpf_jit: " fmt +#include #include #include #include @@ -56,6 +57,7 @@ struct jit_ctx { int idx; int epilogue_offset; int *offset; + int exentry_idx; __le32 *image; u32 stack_size; }; @@ -351,6 +353,67 @@ static void build_epilogue(struct jit_ctx *ctx) emit(A64_RET(A64_LR), ctx); } +#define BPF_FIXUP_OFFSET_MASK GENMASK(26, 0) +#define BPF_FIXUP_REG_MASK GENMASK(31, 27) + +int arm64_bpf_fixup_exception(const struct exception_table_entry *ex, + struct pt_regs *regs) +{ + off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup); + int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup); + + regs->regs[dst_reg] = 0; + regs->pc = (unsigned long)&ex->fixup - offset; + return 1; +} + +/* For accesses to BTF pointers, add an entry to the exception table */ +static int add_exception_handler(const struct bpf_insn *insn, + struct jit_ctx *ctx, + int dst_reg) +{ + off_t offset; + unsigned long pc; + struct exception_table_entry *ex; + + if (!ctx->image) + /* First pass */ + return 0; + + if (BPF_MODE(insn->code) != BPF_PROBE_MEM) + return 0; + + if (!ctx->prog->aux->extable || + WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries)) + return -EINVAL; + + ex = &ctx->prog->aux->extable[ctx->exentry_idx]; + pc = (unsigned long)&ctx->image[ctx->idx - 1]; + + offset = pc - (long)&ex->insn; + if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN)) + return -ERANGE; + ex->insn = offset; + + /* + * Since the extable follows the program, the fixup offset is always + * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value + * to keep things simple, and put the destination register in the upper + * bits. We don't need to worry about buildtime or runtime sort + * modifying the upper bits because the table is already sorted, and + * isn't part of the main exception table. + */ + offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE); + if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset)) + return -ERANGE; + + ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) | + FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg); + + ctx->exentry_idx++; + return 0; +} + /* JITs an eBPF instruction. * Returns: * 0 - successfully JITed an 8-byte eBPF instruction. @@ -375,6 +438,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, u8 jmp_cond, reg; s32 jmp_offset; u32 a64_insn; + int ret; #define check_imm(bits, imm) do { \ if ((((imm) > 0) && ((imm) >> (bits))) || \ @@ -694,7 +758,6 @@ emit_cond_jmp: const u8 r0 = bpf2a64[BPF_REG_0]; bool func_addr_fixed; u64 func_addr; - int ret; ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &func_addr, &func_addr_fixed); @@ -738,6 +801,10 @@ emit_cond_jmp: case BPF_LDX | BPF_MEM | BPF_H: case BPF_LDX | BPF_MEM | BPF_B: case BPF_LDX | BPF_MEM | BPF_DW: + case BPF_LDX | BPF_PROBE_MEM | BPF_DW: + case BPF_LDX | BPF_PROBE_MEM | BPF_W: + case BPF_LDX | BPF_PROBE_MEM | BPF_H: + case BPF_LDX | BPF_PROBE_MEM | BPF_B: emit_a64_mov_i(1, tmp, off, ctx); switch (BPF_SIZE(code)) { case BPF_W: @@ -753,6 +820,10 @@ emit_cond_jmp: emit(A64_LDR64(dst, src, tmp), ctx); break; } + + ret = add_exception_handler(insn, ctx, dst); + if (ret) + return ret; break; /* ST: *(size *)(dst + off) = imm */ @@ -868,6 +939,9 @@ static int validate_code(struct jit_ctx *ctx) return -1; } + if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries)) + return -1; + return 0; } @@ -884,6 +958,7 @@ struct arm64_jit_data { struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) { + int image_size, prog_size, extable_size; struct bpf_prog *tmp, *orig_prog = prog; struct bpf_binary_header *header; struct arm64_jit_data *jit_data; @@ -891,7 +966,6 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) bool tmp_blinded = false; bool extra_pass = false; struct jit_ctx ctx; - int image_size; u8 *image_ptr; if (!prog->jit_requested) @@ -922,7 +996,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) image_ptr = jit_data->image; header = jit_data->header; extra_pass = true; - image_size = sizeof(u32) * ctx.idx; + prog_size = sizeof(u32) * ctx.idx; goto skip_init_ctx; } memset(&ctx, 0, sizeof(ctx)); @@ -950,8 +1024,12 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ctx.epilogue_offset = ctx.idx; build_epilogue(&ctx); + extable_size = prog->aux->num_exentries * + sizeof(struct exception_table_entry); + /* Now we know the actual image size. */ - image_size = sizeof(u32) * ctx.idx; + prog_size = sizeof(u32) * ctx.idx; + image_size = prog_size + extable_size; header = bpf_jit_binary_alloc(image_size, &image_ptr, sizeof(u32), jit_fill_hole); if (header == NULL) { @@ -962,8 +1040,11 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) /* 2. Now, the actual pass. */ ctx.image = (__le32 *)image_ptr; + if (extable_size) + prog->aux->extable = (void *)image_ptr + prog_size; skip_init_ctx: ctx.idx = 0; + ctx.exentry_idx = 0; build_prologue(&ctx, was_classic); @@ -984,7 +1065,7 @@ skip_init_ctx: /* And we're done. */ if (bpf_jit_enable > 1) - bpf_jit_dump(prog->len, image_size, 2, ctx.image); + bpf_jit_dump(prog->len, prog_size, 2, ctx.image); bpf_flush_icache(header, ctx.image + ctx.idx); @@ -1005,7 +1086,7 @@ skip_init_ctx: } prog->bpf_func = (void *)ctx.image; prog->jited = 1; - prog->jited_len = image_size; + prog->jited_len = prog_size; if (!prog->is_func || extra_pass) { bpf_prog_fill_jited_linfo(prog, ctx.offset);