xref: /kernel/linux/linux-6.6/fs/hmdfs/comm/crypto.c (revision 62306a36)
// SPDX-License-Identifier: GPL-2.0
/*
 * fs/hmdfs/comm/crypto.c
 *
 * Copyright (c) 2020-2021 Huawei Device Co., Ltd.
 */
7
#include "crypto.h"

#include <crypto/aead.h>
#include <crypto/hash.h>
#include <linux/string.h>
#include <linux/tcp.h>
#include <net/inet_connection_sock.h>
#include <net/tcp_states.h>
#include <net/tls.h>

#include "hmdfs.h"
18
19static void tls_crypto_set_key(struct connection *conn_impl, int tx)
20{
21	int rc = 0;
22	struct tcp_handle *tcp = conn_impl->connect_handle;
23	struct tls_context *ctx = NULL;
24	struct cipher_context *cctx = NULL;
25	struct tls_sw_context_tx *sw_ctx_tx = NULL;
26	struct tls_sw_context_rx *sw_ctx_rx = NULL;
27	struct crypto_aead **aead = NULL;
28	struct tls12_crypto_info_aes_gcm_128 *crypto_info = NULL;
29
30	lock_sock(tcp->sock->sk);
31	ctx = tls_get_ctx(tcp->sock->sk);
32	if (tx) {
33		crypto_info = &conn_impl->send_crypto_info;
34		cctx = &ctx->tx;
35		sw_ctx_tx = tls_sw_ctx_tx(ctx);
36		aead = &sw_ctx_tx->aead_send;
37	} else {
38		crypto_info = &conn_impl->recv_crypto_info;
39		cctx = &ctx->rx;
40		sw_ctx_rx = tls_sw_ctx_rx(ctx);
41		aead = &sw_ctx_rx->aead_recv;
42	}
43
44	memcpy(cctx->iv, crypto_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
45	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, crypto_info->iv,
46	       TLS_CIPHER_AES_GCM_128_IV_SIZE);
47	memcpy(cctx->rec_seq, crypto_info->rec_seq,
48	       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
49	rc = crypto_aead_setkey(*aead, crypto_info->key,
50				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
51	if (rc)
52		hmdfs_err("crypto set key error");
53	release_sock(tcp->sock->sk);
54}
55
56int tls_crypto_info_init(struct connection *conn_impl)
57{
58	int ret = 0;
59	u8 key_meterial[HMDFS_KEY_SIZE];
60	struct tcp_handle *tcp =
61		(struct tcp_handle *)(conn_impl->connect_handle);
62	if (!tcp)
63		return -EINVAL;
64	// send
65	update_key(conn_impl->send_key, key_meterial, HKDF_TYPE_IV);
66	ret = tcp->sock->ops->setsockopt(tcp->sock, SOL_TCP, TCP_ULP,
67							KERNEL_SOCKPTR("tls"), sizeof("tls"));
68	if (ret)
69		hmdfs_err("set tls error %d", ret);
70	tcp->connect->send_crypto_info.info.version = TLS_1_2_VERSION;
71	tcp->connect->send_crypto_info.info.cipher_type =
72		TLS_CIPHER_AES_GCM_128;
73
74	memcpy(tcp->connect->send_crypto_info.key, tcp->connect->send_key,
75	       TLS_CIPHER_AES_GCM_128_KEY_SIZE);
76	memcpy(tcp->connect->send_crypto_info.iv,
77	       key_meterial + CRYPTO_IV_OFFSET, TLS_CIPHER_AES_GCM_128_IV_SIZE);
78	memcpy(tcp->connect->send_crypto_info.salt,
79	       key_meterial + CRYPTO_SALT_OFFSET,
80	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
81	memcpy(tcp->connect->send_crypto_info.rec_seq,
82	       key_meterial + CRYPTO_SEQ_OFFSET,
83	       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
84
85	ret = tcp->sock->ops->setsockopt(tcp->sock, SOL_TLS, TLS_TX,
86				KERNEL_SOCKPTR(&(tcp->connect->send_crypto_info)),
87				sizeof(tcp->connect->send_crypto_info));
88	if (ret)
89		hmdfs_err("set tls send_crypto_info error %d", ret);
90
91	// recv
92	update_key(tcp->connect->recv_key, key_meterial, HKDF_TYPE_IV);
93	tcp->connect->recv_crypto_info.info.version = TLS_1_2_VERSION;
94	tcp->connect->recv_crypto_info.info.cipher_type =
95		TLS_CIPHER_AES_GCM_128;
96
97	memcpy(tcp->connect->recv_crypto_info.key, tcp->connect->recv_key,
98	       TLS_CIPHER_AES_GCM_128_KEY_SIZE);
99	memcpy(tcp->connect->recv_crypto_info.iv,
100	       key_meterial + CRYPTO_IV_OFFSET, TLS_CIPHER_AES_GCM_128_IV_SIZE);
101	memcpy(tcp->connect->recv_crypto_info.salt,
102	       key_meterial + CRYPTO_SALT_OFFSET,
103	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
104	memcpy(tcp->connect->recv_crypto_info.rec_seq,
105	       key_meterial + CRYPTO_SEQ_OFFSET,
106	       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
107	memset(key_meterial, 0, HMDFS_KEY_SIZE);
108
109	ret = tcp->sock->ops->setsockopt(tcp->sock, SOL_TLS, TLS_RX,
110				KERNEL_SOCKPTR(&(tcp->connect->recv_crypto_info)),
111				sizeof(tcp->connect->recv_crypto_info));
112	if (ret)
113		hmdfs_err("set tls recv_crypto_info error %d", ret);
114	return ret;
115}
116
117static int tls_set_tx(struct tcp_handle *tcp)
118{
119	int ret = 0;
120	u8 new_key[HMDFS_KEY_SIZE];
121	u8 key_meterial[HMDFS_KEY_SIZE];
122
123	ret = update_key(tcp->connect->send_key, new_key, HKDF_TYPE_REKEY);
124	if (ret < 0)
125		return ret;
126	memcpy(tcp->connect->send_key, new_key, HMDFS_KEY_SIZE);
127	ret = update_key(tcp->connect->send_key, key_meterial, HKDF_TYPE_IV);
128	if (ret < 0)
129		return ret;
130
131	memcpy(tcp->connect->send_crypto_info.key, tcp->connect->send_key,
132	       TLS_CIPHER_AES_GCM_128_KEY_SIZE);
133	memcpy(tcp->connect->send_crypto_info.iv,
134	       key_meterial + CRYPTO_IV_OFFSET, TLS_CIPHER_AES_GCM_128_IV_SIZE);
135	memcpy(tcp->connect->send_crypto_info.salt,
136	       key_meterial + CRYPTO_SALT_OFFSET,
137	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
138	memcpy(tcp->connect->send_crypto_info.rec_seq,
139	       key_meterial + CRYPTO_SEQ_OFFSET,
140	       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
141	memset(new_key, 0, HMDFS_KEY_SIZE);
142	memset(key_meterial, 0, HMDFS_KEY_SIZE);
143
144	tls_crypto_set_key(tcp->connect, 1);
145	return 0;
146}
147
148static int tls_set_rx(struct tcp_handle *tcp)
149{
150	int ret = 0;
151	u8 new_key[HMDFS_KEY_SIZE];
152	u8 key_meterial[HMDFS_KEY_SIZE];
153
154	ret = update_key(tcp->connect->recv_key, new_key, HKDF_TYPE_REKEY);
155	if (ret < 0)
156		return ret;
157	memcpy(tcp->connect->recv_key, new_key, HMDFS_KEY_SIZE);
158	ret = update_key(tcp->connect->recv_key, key_meterial, HKDF_TYPE_IV);
159	if (ret < 0)
160		return ret;
161
162	memcpy(tcp->connect->recv_crypto_info.key, tcp->connect->recv_key,
163	       TLS_CIPHER_AES_GCM_128_KEY_SIZE);
164	memcpy(tcp->connect->recv_crypto_info.iv,
165	       key_meterial + CRYPTO_IV_OFFSET, TLS_CIPHER_AES_GCM_128_IV_SIZE);
166	memcpy(tcp->connect->recv_crypto_info.salt,
167	       key_meterial + CRYPTO_SALT_OFFSET,
168	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
169	memcpy(tcp->connect->recv_crypto_info.rec_seq,
170	       key_meterial + CRYPTO_SEQ_OFFSET,
171	       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
172	memset(new_key, 0, HMDFS_KEY_SIZE);
173	memset(key_meterial, 0, HMDFS_KEY_SIZE);
174	tls_crypto_set_key(tcp->connect, 0);
175	return 0;
176}
177
178int set_crypto_info(struct connection *conn_impl, int set_type)
179{
180	int ret = 0;
181	struct tcp_handle *tcp =
182		(struct tcp_handle *)(conn_impl->connect_handle);
183	if (!tcp)
184		return -EINVAL;
185
186	if (set_type == SET_CRYPTO_SEND) {
187		ret = tls_set_tx(tcp);
188		if (ret) {
189			hmdfs_err("tls set tx fail");
190			return ret;
191		}
192	}
193	if (set_type == SET_CRYPTO_RECV) {
194		ret = tls_set_rx(tcp);
195		if (ret) {
196			hmdfs_err("tls set rx fail");
197			return ret;
198		}
199	}
200	hmdfs_info("KTLS setting success");
201	return ret;
202}
203
204static int hmac_sha256(u8 *key, u8 key_len, char *info, u8 info_len, u8 *output)
205{
206	struct crypto_shash *tfm = NULL;
207	struct shash_desc *shash = NULL;
208	int ret = 0;
209
210	if (!key)
211		return -EINVAL;
212
213	tfm = crypto_alloc_shash("hmac(sha256)", 0, 0);
214	if (IS_ERR(tfm)) {
215		hmdfs_err("crypto_alloc_ahash failed: err %ld", PTR_ERR(tfm));
216		return PTR_ERR(tfm);
217	}
218
219	ret = crypto_shash_setkey(tfm, key, key_len);
220	if (ret) {
221		hmdfs_err("crypto_ahash_setkey failed: err %d", ret);
222		goto failed;
223	}
224
225	shash = kzalloc(sizeof(*shash) + crypto_shash_descsize(tfm),
226			GFP_KERNEL);
227	if (!shash) {
228		ret = -ENOMEM;
229		goto failed;
230	}
231
232	shash->tfm = tfm;
233
234	ret = crypto_shash_digest(shash, info, info_len, output);
235
236	kfree(shash);
237
238failed:
239	crypto_free_shash(tfm);
240	return ret;
241}
242
/*
 * HKDF labels used by update_key(), indexed by its @type argument.
 * g_key_lable_len[i] is strlen(g_key_lable[i]) (no NUL terminator). Both
 * tables come from the same literals, so the lengths can no longer drift from
 * the strings. NOTE(review): the index order presumably matches the
 * HKDF_TYPE_* values in crypto.h — confirm there.
 */
#define HMDFS_KTLS_LABEL_INITIATOR "ktls key initiator"
#define HMDFS_KTLS_LABEL_ACCEPTER "ktls key accepter"
#define HMDFS_KTLS_LABEL_UPDATE "ktls key update"
#define HMDFS_KTLS_LABEL_IV_SALT "ktls iv&salt"

static const char *const g_key_lable[] = {
	HMDFS_KTLS_LABEL_INITIATOR,
	HMDFS_KTLS_LABEL_ACCEPTER,
	HMDFS_KTLS_LABEL_UPDATE,
	HMDFS_KTLS_LABEL_IV_SALT,
};

static const int g_key_lable_len[] = {
	sizeof(HMDFS_KTLS_LABEL_INITIATOR) - 1,
	sizeof(HMDFS_KTLS_LABEL_ACCEPTER) - 1,
	sizeof(HMDFS_KTLS_LABEL_UPDATE) - 1,
	sizeof(HMDFS_KTLS_LABEL_IV_SALT) - 1,
};

_Static_assert(sizeof(g_key_lable) / sizeof(g_key_lable[0]) ==
	       sizeof(g_key_lable_len) / sizeof(g_key_lable_len[0]),
	       "every key label needs a matching length entry");

#undef HMDFS_KTLS_LABEL_INITIATOR
#undef HMDFS_KTLS_LABEL_ACCEPTER
#undef HMDFS_KTLS_LABEL_UPDATE
#undef HMDFS_KTLS_LABEL_IV_SALT
247
248int update_key(__u8 *old_key, __u8 *new_key, int type)
249{
250	int ret = 0;
251	char lable[MAX_LABLE_SIZE];
252	u8 lable_size;
253
254	lable_size = g_key_lable_len[type] + sizeof(u16) + sizeof(char);
255	*((u16 *)lable) = HMDFS_KEY_SIZE;
256	memcpy(lable + sizeof(u16), g_key_lable[type], g_key_lable_len[type]);
257	*(lable + sizeof(u16) + g_key_lable_len[type]) = 0x01;
258	ret = hmac_sha256(old_key, HMDFS_KEY_SIZE, lable, lable_size, new_key);
259	if (ret < 0)
260		hmdfs_err("hmac sha256 error");
261	return ret;
262}
263