bluetooth: switch to AES library

The bluetooth code uses a bare AES cipher for the encryption operations.
Given that it carries out a set_key() operation right before every
encryption operation, this is clearly not a hot path, and so the use of
the cipher interface (which provides the best implementation available
on the system) is not really required.

In fact, when using a cipher like AES-NI or AES-CE, both the set_key()
and the encrypt() operations involve en/disabling preemption as well as
stacking and unstacking the SIMD context, and this is most certainly
not worth it for encrypting 16 bytes of data.

So let's switch to the new lightweight library interface instead.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
This commit is contained in:
Ard Biesheuvel 2019-07-02 21:41:41 +02:00 committed by Herbert Xu
parent 0a5dff9882
commit 28a220aac5
2 changed files with 33 additions and 73 deletions

View File

@ -10,7 +10,8 @@ menuconfig BT
select CRC16 select CRC16
select CRYPTO select CRYPTO
select CRYPTO_BLKCIPHER select CRYPTO_BLKCIPHER
select CRYPTO_AES select CRYPTO_LIB_AES
imply CRYPTO_AES
select CRYPTO_CMAC select CRYPTO_CMAC
select CRYPTO_ECB select CRYPTO_ECB
select CRYPTO_SHA256 select CRYPTO_SHA256

View File

@ -23,6 +23,7 @@
#include <linux/debugfs.h> #include <linux/debugfs.h>
#include <linux/scatterlist.h> #include <linux/scatterlist.h>
#include <linux/crypto.h> #include <linux/crypto.h>
#include <crypto/aes.h>
#include <crypto/algapi.h> #include <crypto/algapi.h>
#include <crypto/b128ops.h> #include <crypto/b128ops.h>
#include <crypto/hash.h> #include <crypto/hash.h>
@ -88,7 +89,6 @@ struct smp_dev {
u8 local_rand[16]; u8 local_rand[16];
bool debug_key; bool debug_key;
struct crypto_cipher *tfm_aes;
struct crypto_shash *tfm_cmac; struct crypto_shash *tfm_cmac;
struct crypto_kpp *tfm_ecdh; struct crypto_kpp *tfm_ecdh;
}; };
@ -127,7 +127,6 @@ struct smp_chan {
u8 dhkey[32]; u8 dhkey[32];
u8 mackey[16]; u8 mackey[16];
struct crypto_cipher *tfm_aes;
struct crypto_shash *tfm_cmac; struct crypto_shash *tfm_cmac;
struct crypto_kpp *tfm_ecdh; struct crypto_kpp *tfm_ecdh;
}; };
@ -377,22 +376,18 @@ static int smp_h7(struct crypto_shash *tfm_cmac, const u8 w[16],
* s1 and ah. * s1 and ah.
*/ */
static int smp_e(struct crypto_cipher *tfm, const u8 *k, u8 *r) static int smp_e(const u8 *k, u8 *r)
{ {
struct crypto_aes_ctx ctx;
uint8_t tmp[16], data[16]; uint8_t tmp[16], data[16];
int err; int err;
SMP_DBG("k %16phN r %16phN", k, r); SMP_DBG("k %16phN r %16phN", k, r);
if (!tfm) {
BT_ERR("tfm %p", tfm);
return -EINVAL;
}
/* The most significant octet of key corresponds to k[0] */ /* The most significant octet of key corresponds to k[0] */
swap_buf(k, tmp, 16); swap_buf(k, tmp, 16);
err = crypto_cipher_setkey(tfm, tmp, 16); err = aes_expandkey(&ctx, tmp, 16);
if (err) { if (err) {
BT_ERR("cipher setkey failed: %d", err); BT_ERR("cipher setkey failed: %d", err);
return err; return err;
@ -401,17 +396,18 @@ static int smp_e(struct crypto_cipher *tfm, const u8 *k, u8 *r)
/* Most significant octet of plaintextData corresponds to data[0] */ /* Most significant octet of plaintextData corresponds to data[0] */
swap_buf(r, data, 16); swap_buf(r, data, 16);
crypto_cipher_encrypt_one(tfm, data, data); aes_encrypt(&ctx, data, data);
/* Most significant octet of encryptedData corresponds to data[0] */ /* Most significant octet of encryptedData corresponds to data[0] */
swap_buf(data, r, 16); swap_buf(data, r, 16);
SMP_DBG("r %16phN", r); SMP_DBG("r %16phN", r);
memzero_explicit(&ctx, sizeof (ctx));
return err; return err;
} }
static int smp_c1(struct crypto_cipher *tfm_aes, const u8 k[16], static int smp_c1(const u8 k[16],
const u8 r[16], const u8 preq[7], const u8 pres[7], u8 _iat, const u8 r[16], const u8 preq[7], const u8 pres[7], u8 _iat,
const bdaddr_t *ia, u8 _rat, const bdaddr_t *ra, u8 res[16]) const bdaddr_t *ia, u8 _rat, const bdaddr_t *ra, u8 res[16])
{ {
@ -436,7 +432,7 @@ static int smp_c1(struct crypto_cipher *tfm_aes, const u8 k[16],
u128_xor((u128 *) res, (u128 *) r, (u128 *) p1); u128_xor((u128 *) res, (u128 *) r, (u128 *) p1);
/* res = e(k, res) */ /* res = e(k, res) */
err = smp_e(tfm_aes, k, res); err = smp_e(k, res);
if (err) { if (err) {
BT_ERR("Encrypt data error"); BT_ERR("Encrypt data error");
return err; return err;
@ -453,14 +449,14 @@ static int smp_c1(struct crypto_cipher *tfm_aes, const u8 k[16],
u128_xor((u128 *) res, (u128 *) res, (u128 *) p2); u128_xor((u128 *) res, (u128 *) res, (u128 *) p2);
/* res = e(k, res) */ /* res = e(k, res) */
err = smp_e(tfm_aes, k, res); err = smp_e(k, res);
if (err) if (err)
BT_ERR("Encrypt data error"); BT_ERR("Encrypt data error");
return err; return err;
} }
static int smp_s1(struct crypto_cipher *tfm_aes, const u8 k[16], static int smp_s1(const u8 k[16],
const u8 r1[16], const u8 r2[16], u8 _r[16]) const u8 r1[16], const u8 r2[16], u8 _r[16])
{ {
int err; int err;
@ -469,15 +465,14 @@ static int smp_s1(struct crypto_cipher *tfm_aes, const u8 k[16],
memcpy(_r, r2, 8); memcpy(_r, r2, 8);
memcpy(_r + 8, r1, 8); memcpy(_r + 8, r1, 8);
err = smp_e(tfm_aes, k, _r); err = smp_e(k, _r);
if (err) if (err)
BT_ERR("Encrypt data error"); BT_ERR("Encrypt data error");
return err; return err;
} }
static int smp_ah(struct crypto_cipher *tfm, const u8 irk[16], static int smp_ah(const u8 irk[16], const u8 r[3], u8 res[3])
const u8 r[3], u8 res[3])
{ {
u8 _res[16]; u8 _res[16];
int err; int err;
@ -486,7 +481,7 @@ static int smp_ah(struct crypto_cipher *tfm, const u8 irk[16],
memcpy(_res, r, 3); memcpy(_res, r, 3);
memset(_res + 3, 0, 13); memset(_res + 3, 0, 13);
err = smp_e(tfm, irk, _res); err = smp_e(irk, _res);
if (err) { if (err) {
BT_ERR("Encrypt error"); BT_ERR("Encrypt error");
return err; return err;
@ -518,7 +513,7 @@ bool smp_irk_matches(struct hci_dev *hdev, const u8 irk[16],
BT_DBG("RPA %pMR IRK %*phN", bdaddr, 16, irk); BT_DBG("RPA %pMR IRK %*phN", bdaddr, 16, irk);
err = smp_ah(smp->tfm_aes, irk, &bdaddr->b[3], hash); err = smp_ah(irk, &bdaddr->b[3], hash);
if (err) if (err)
return false; return false;
@ -541,7 +536,7 @@ int smp_generate_rpa(struct hci_dev *hdev, const u8 irk[16], bdaddr_t *rpa)
rpa->b[5] &= 0x3f; /* Clear two most significant bits */ rpa->b[5] &= 0x3f; /* Clear two most significant bits */
rpa->b[5] |= 0x40; /* Set second most significant bit */ rpa->b[5] |= 0x40; /* Set second most significant bit */
err = smp_ah(smp->tfm_aes, irk, &rpa->b[3], rpa->b); err = smp_ah(irk, &rpa->b[3], rpa->b);
if (err < 0) if (err < 0)
return err; return err;
@ -768,7 +763,6 @@ static void smp_chan_destroy(struct l2cap_conn *conn)
kzfree(smp->slave_csrk); kzfree(smp->slave_csrk);
kzfree(smp->link_key); kzfree(smp->link_key);
crypto_free_cipher(smp->tfm_aes);
crypto_free_shash(smp->tfm_cmac); crypto_free_shash(smp->tfm_cmac);
crypto_free_kpp(smp->tfm_ecdh); crypto_free_kpp(smp->tfm_ecdh);
@ -957,7 +951,7 @@ static u8 smp_confirm(struct smp_chan *smp)
BT_DBG("conn %p", conn); BT_DBG("conn %p", conn);
ret = smp_c1(smp->tfm_aes, smp->tk, smp->prnd, smp->preq, smp->prsp, ret = smp_c1(smp->tk, smp->prnd, smp->preq, smp->prsp,
conn->hcon->init_addr_type, &conn->hcon->init_addr, conn->hcon->init_addr_type, &conn->hcon->init_addr,
conn->hcon->resp_addr_type, &conn->hcon->resp_addr, conn->hcon->resp_addr_type, &conn->hcon->resp_addr,
cp.confirm_val); cp.confirm_val);
@ -983,12 +977,9 @@ static u8 smp_random(struct smp_chan *smp)
u8 confirm[16]; u8 confirm[16];
int ret; int ret;
if (IS_ERR_OR_NULL(smp->tfm_aes))
return SMP_UNSPECIFIED;
BT_DBG("conn %p %s", conn, conn->hcon->out ? "master" : "slave"); BT_DBG("conn %p %s", conn, conn->hcon->out ? "master" : "slave");
ret = smp_c1(smp->tfm_aes, smp->tk, smp->rrnd, smp->preq, smp->prsp, ret = smp_c1(smp->tk, smp->rrnd, smp->preq, smp->prsp,
hcon->init_addr_type, &hcon->init_addr, hcon->init_addr_type, &hcon->init_addr,
hcon->resp_addr_type, &hcon->resp_addr, confirm); hcon->resp_addr_type, &hcon->resp_addr, confirm);
if (ret) if (ret)
@ -1005,7 +996,7 @@ static u8 smp_random(struct smp_chan *smp)
__le64 rand = 0; __le64 rand = 0;
__le16 ediv = 0; __le16 ediv = 0;
smp_s1(smp->tfm_aes, smp->tk, smp->rrnd, smp->prnd, stk); smp_s1(smp->tk, smp->rrnd, smp->prnd, stk);
if (test_and_set_bit(HCI_CONN_ENCRYPT_PEND, &hcon->flags)) if (test_and_set_bit(HCI_CONN_ENCRYPT_PEND, &hcon->flags))
return SMP_UNSPECIFIED; return SMP_UNSPECIFIED;
@ -1021,7 +1012,7 @@ static u8 smp_random(struct smp_chan *smp)
smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, sizeof(smp->prnd), smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, sizeof(smp->prnd),
smp->prnd); smp->prnd);
smp_s1(smp->tfm_aes, smp->tk, smp->prnd, smp->rrnd, stk); smp_s1(smp->tk, smp->prnd, smp->rrnd, stk);
if (hcon->pending_sec_level == BT_SECURITY_HIGH) if (hcon->pending_sec_level == BT_SECURITY_HIGH)
auth = 1; auth = 1;
@ -1389,16 +1380,10 @@ static struct smp_chan *smp_chan_create(struct l2cap_conn *conn)
if (!smp) if (!smp)
return NULL; return NULL;
smp->tfm_aes = crypto_alloc_cipher("aes", 0, 0);
if (IS_ERR(smp->tfm_aes)) {
BT_ERR("Unable to create AES crypto context");
goto zfree_smp;
}
smp->tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0); smp->tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0);
if (IS_ERR(smp->tfm_cmac)) { if (IS_ERR(smp->tfm_cmac)) {
BT_ERR("Unable to create CMAC crypto context"); BT_ERR("Unable to create CMAC crypto context");
goto free_cipher; goto zfree_smp;
} }
smp->tfm_ecdh = crypto_alloc_kpp("ecdh", CRYPTO_ALG_INTERNAL, 0); smp->tfm_ecdh = crypto_alloc_kpp("ecdh", CRYPTO_ALG_INTERNAL, 0);
@ -1420,8 +1405,6 @@ static struct smp_chan *smp_chan_create(struct l2cap_conn *conn)
free_shash: free_shash:
crypto_free_shash(smp->tfm_cmac); crypto_free_shash(smp->tfm_cmac);
free_cipher:
crypto_free_cipher(smp->tfm_aes);
zfree_smp: zfree_smp:
kzfree(smp); kzfree(smp);
return NULL; return NULL;
@ -3232,7 +3215,6 @@ static struct l2cap_chan *smp_add_cid(struct hci_dev *hdev, u16 cid)
{ {
struct l2cap_chan *chan; struct l2cap_chan *chan;
struct smp_dev *smp; struct smp_dev *smp;
struct crypto_cipher *tfm_aes;
struct crypto_shash *tfm_cmac; struct crypto_shash *tfm_cmac;
struct crypto_kpp *tfm_ecdh; struct crypto_kpp *tfm_ecdh;
@ -3245,17 +3227,9 @@ static struct l2cap_chan *smp_add_cid(struct hci_dev *hdev, u16 cid)
if (!smp) if (!smp)
return ERR_PTR(-ENOMEM); return ERR_PTR(-ENOMEM);
tfm_aes = crypto_alloc_cipher("aes", 0, 0);
if (IS_ERR(tfm_aes)) {
BT_ERR("Unable to create AES crypto context");
kzfree(smp);
return ERR_CAST(tfm_aes);
}
tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0); tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0);
if (IS_ERR(tfm_cmac)) { if (IS_ERR(tfm_cmac)) {
BT_ERR("Unable to create CMAC crypto context"); BT_ERR("Unable to create CMAC crypto context");
crypto_free_cipher(tfm_aes);
kzfree(smp); kzfree(smp);
return ERR_CAST(tfm_cmac); return ERR_CAST(tfm_cmac);
} }
@ -3264,13 +3238,11 @@ static struct l2cap_chan *smp_add_cid(struct hci_dev *hdev, u16 cid)
if (IS_ERR(tfm_ecdh)) { if (IS_ERR(tfm_ecdh)) {
BT_ERR("Unable to create ECDH crypto context"); BT_ERR("Unable to create ECDH crypto context");
crypto_free_shash(tfm_cmac); crypto_free_shash(tfm_cmac);
crypto_free_cipher(tfm_aes);
kzfree(smp); kzfree(smp);
return ERR_CAST(tfm_ecdh); return ERR_CAST(tfm_ecdh);
} }
smp->local_oob = false; smp->local_oob = false;
smp->tfm_aes = tfm_aes;
smp->tfm_cmac = tfm_cmac; smp->tfm_cmac = tfm_cmac;
smp->tfm_ecdh = tfm_ecdh; smp->tfm_ecdh = tfm_ecdh;
@ -3278,7 +3250,6 @@ create_chan:
chan = l2cap_chan_create(); chan = l2cap_chan_create();
if (!chan) { if (!chan) {
if (smp) { if (smp) {
crypto_free_cipher(smp->tfm_aes);
crypto_free_shash(smp->tfm_cmac); crypto_free_shash(smp->tfm_cmac);
crypto_free_kpp(smp->tfm_ecdh); crypto_free_kpp(smp->tfm_ecdh);
kzfree(smp); kzfree(smp);
@ -3326,7 +3297,6 @@ static void smp_del_chan(struct l2cap_chan *chan)
smp = chan->data; smp = chan->data;
if (smp) { if (smp) {
chan->data = NULL; chan->data = NULL;
crypto_free_cipher(smp->tfm_aes);
crypto_free_shash(smp->tfm_cmac); crypto_free_shash(smp->tfm_cmac);
crypto_free_kpp(smp->tfm_ecdh); crypto_free_kpp(smp->tfm_ecdh);
kzfree(smp); kzfree(smp);
@ -3582,7 +3552,7 @@ static int __init test_debug_key(struct crypto_kpp *tfm_ecdh)
return 0; return 0;
} }
static int __init test_ah(struct crypto_cipher *tfm_aes) static int __init test_ah(void)
{ {
const u8 irk[16] = { const u8 irk[16] = {
0x9b, 0x7d, 0x39, 0x0a, 0xa6, 0x10, 0x10, 0x34, 0x9b, 0x7d, 0x39, 0x0a, 0xa6, 0x10, 0x10, 0x34,
@ -3592,7 +3562,7 @@ static int __init test_ah(struct crypto_cipher *tfm_aes)
u8 res[3]; u8 res[3];
int err; int err;
err = smp_ah(tfm_aes, irk, r, res); err = smp_ah(irk, r, res);
if (err) if (err)
return err; return err;
@ -3602,7 +3572,7 @@ static int __init test_ah(struct crypto_cipher *tfm_aes)
return 0; return 0;
} }
static int __init test_c1(struct crypto_cipher *tfm_aes) static int __init test_c1(void)
{ {
const u8 k[16] = { const u8 k[16] = {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
@ -3622,7 +3592,7 @@ static int __init test_c1(struct crypto_cipher *tfm_aes)
u8 res[16]; u8 res[16];
int err; int err;
err = smp_c1(tfm_aes, k, r, preq, pres, _iat, &ia, _rat, &ra, res); err = smp_c1(k, r, preq, pres, _iat, &ia, _rat, &ra, res);
if (err) if (err)
return err; return err;
@ -3632,7 +3602,7 @@ static int __init test_c1(struct crypto_cipher *tfm_aes)
return 0; return 0;
} }
static int __init test_s1(struct crypto_cipher *tfm_aes) static int __init test_s1(void)
{ {
const u8 k[16] = { const u8 k[16] = {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
@ -3647,7 +3617,7 @@ static int __init test_s1(struct crypto_cipher *tfm_aes)
u8 res[16]; u8 res[16];
int err; int err;
err = smp_s1(tfm_aes, k, r1, r2, res); err = smp_s1(k, r1, r2, res);
if (err) if (err)
return err; return err;
@ -3828,8 +3798,7 @@ static const struct file_operations test_smp_fops = {
.llseek = default_llseek, .llseek = default_llseek,
}; };
static int __init run_selftests(struct crypto_cipher *tfm_aes, static int __init run_selftests(struct crypto_shash *tfm_cmac,
struct crypto_shash *tfm_cmac,
struct crypto_kpp *tfm_ecdh) struct crypto_kpp *tfm_ecdh)
{ {
ktime_t calltime, delta, rettime; ktime_t calltime, delta, rettime;
@ -3844,19 +3813,19 @@ static int __init run_selftests(struct crypto_cipher *tfm_aes,
goto done; goto done;
} }
err = test_ah(tfm_aes); err = test_ah();
if (err) { if (err) {
BT_ERR("smp_ah test failed"); BT_ERR("smp_ah test failed");
goto done; goto done;
} }
err = test_c1(tfm_aes); err = test_c1();
if (err) { if (err) {
BT_ERR("smp_c1 test failed"); BT_ERR("smp_c1 test failed");
goto done; goto done;
} }
err = test_s1(tfm_aes); err = test_s1();
if (err) { if (err) {
BT_ERR("smp_s1 test failed"); BT_ERR("smp_s1 test failed");
goto done; goto done;
@ -3913,21 +3882,13 @@ done:
int __init bt_selftest_smp(void) int __init bt_selftest_smp(void)
{ {
struct crypto_cipher *tfm_aes;
struct crypto_shash *tfm_cmac; struct crypto_shash *tfm_cmac;
struct crypto_kpp *tfm_ecdh; struct crypto_kpp *tfm_ecdh;
int err; int err;
tfm_aes = crypto_alloc_cipher("aes", 0, 0);
if (IS_ERR(tfm_aes)) {
BT_ERR("Unable to create AES crypto context");
return PTR_ERR(tfm_aes);
}
tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0); tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0);
if (IS_ERR(tfm_cmac)) { if (IS_ERR(tfm_cmac)) {
BT_ERR("Unable to create CMAC crypto context"); BT_ERR("Unable to create CMAC crypto context");
crypto_free_cipher(tfm_aes);
return PTR_ERR(tfm_cmac); return PTR_ERR(tfm_cmac);
} }
@ -3935,14 +3896,12 @@ int __init bt_selftest_smp(void)
if (IS_ERR(tfm_ecdh)) { if (IS_ERR(tfm_ecdh)) {
BT_ERR("Unable to create ECDH crypto context"); BT_ERR("Unable to create ECDH crypto context");
crypto_free_shash(tfm_cmac); crypto_free_shash(tfm_cmac);
crypto_free_cipher(tfm_aes);
return PTR_ERR(tfm_ecdh); return PTR_ERR(tfm_ecdh);
} }
err = run_selftests(tfm_aes, tfm_cmac, tfm_ecdh); err = run_selftests(tfm_cmac, tfm_ecdh);
crypto_free_shash(tfm_cmac); crypto_free_shash(tfm_cmac);
crypto_free_cipher(tfm_aes);
crypto_free_kpp(tfm_ecdh); crypto_free_kpp(tfm_ecdh);
return err; return err;