xref: /kernel/linux/linux-6.6/fs/hmdfs/comm/transport.c (revision 62306a36)
1// SPDX-License-Identifier: GPL-2.0
2/*
3 * fs/hmdfs/comm/transport.c
4 *
5 * Copyright (c) 2020-2021 Huawei Device Co., Ltd.
6 */
7
8#include "transport.h"
9
10#include <linux/freezer.h>
11#include <linux/highmem.h>
12#include <linux/kthread.h>
13#include <linux/module.h>
14#include <linux/net.h>
15#include <linux/tcp.h>
16#include <linux/time.h>
17#include <linux/file.h>
18#include <linux/sched/mm.h>
19
20#include "device_node.h"
21#include "hmdfs_trace.h"
22#include "socket_adapter.h"
23#include "authority/authentication.h"
24
25#ifdef CONFIG_HMDFS_FS_ENCRYPTION
26#include <net/tls.h>
27#include "crypto.h"
28#endif
29
/* Handler invoked for a received message; selected by connection state. */
typedef void (*connect_recv_handler)(struct connection *, void *, void *,
				     __u32);

/*
 * Per-state receive dispatch table.  All handshake states share one
 * handler, WORKING uses the normal message path, and STOP/NEGO_FAIL have
 * NULL entries — callers check for NULL before dispatching and drop the
 * message in that case.
 */
static connect_recv_handler connect_recv_callback[CONNECT_STAT_COUNT] = {
	[CONNECT_STAT_WAIT_REQUEST] = connection_handshake_recv_handler,
	[CONNECT_STAT_WAIT_RESPONSE] = connection_handshake_recv_handler,
	[CONNECT_STAT_WORKING] = connection_working_recv_handler,
	[CONNECT_STAT_STOP] = NULL,
	[CONNECT_STAT_WAIT_ACK] = connection_handshake_recv_handler,
	[CONNECT_STAT_NEGO_FAIL] = NULL,
};
41
42static int recvmsg_nofs(struct socket *sock, struct msghdr *msg,
43			struct kvec *vec, size_t num, size_t size, int flags)
44{
45	unsigned int nofs_flags;
46	int ret;
47
48	/* enable NOFS for memory allocation */
49	nofs_flags = memalloc_nofs_save();
50	ret = kernel_recvmsg(sock, msg, vec, num, size, flags);
51	memalloc_nofs_restore(nofs_flags);
52
53	return ret;
54}
55
56static int sendmsg_nofs(struct socket *sock, struct msghdr *msg,
57			struct kvec *vec, size_t num, size_t size)
58{
59	unsigned int nofs_flags;
60	int ret;
61
62	/* enable NOFS for memory allocation */
63	nofs_flags = memalloc_nofs_save();
64	ret = kernel_sendmsg(sock, msg, vec, num, size);
65	memalloc_nofs_restore(nofs_flags);
66
67	return ret;
68}
69
70static int tcp_set_recvtimeo(struct socket *sock, int timeout)
71{
72	long jiffies_left = timeout * msecs_to_jiffies(MSEC_PER_SEC);
73
74	tcp_sock_set_nodelay(sock->sk);
75	tcp_sock_set_user_timeout(sock->sk, jiffies_left);
76	return 0;
77}
78
79uint32_t hmdfs_tcpi_rtt(struct hmdfs_peer *con)
80{
81	uint32_t rtt_us = 0;
82	struct connection *conn_impl = NULL;
83	struct tcp_handle *tcp = NULL;
84
85	conn_impl = get_conn_impl(con, CONNECT_TYPE_TCP);
86	if (!conn_impl)
87		return rtt_us;
88	tcp = (struct tcp_handle *)(conn_impl->connect_handle);
89	if (tcp->sock)
90		rtt_us = tcp_sk(tcp->sock->sk)->srtt_us >> 3;
91	connection_put(conn_impl);
92	return rtt_us;
93}
94
95static int tcp_read_head_from_socket(struct socket *sock, void *buf,
96				     unsigned int to_read)
97{
98	int rc = 0;
99	struct msghdr hmdfs_msg;
100	struct kvec iov;
101
102	iov.iov_base = buf;
103	iov.iov_len = to_read;
104	memset(&hmdfs_msg, 0, sizeof(hmdfs_msg));
105	hmdfs_msg.msg_flags = MSG_WAITALL;
106	hmdfs_msg.msg_control = NULL;
107	hmdfs_msg.msg_controllen = 0;
108	rc = recvmsg_nofs(sock, &hmdfs_msg, &iov, 1, to_read,
109			  hmdfs_msg.msg_flags);
110	if (rc == -EAGAIN || rc == -ETIMEDOUT || rc == -EINTR ||
111	    rc == -EBADMSG) {
112		usleep_range(1000, 2000);
113		return -EAGAIN;
114	}
115	// error occurred
116	if (rc != to_read) {
117		hmdfs_err("tcp recv error %d", rc);
118		return -ESHUTDOWN;
119	}
120	return 0;
121}
122
/*
 * Read exactly @to_read bytes of message body from @sock, retrying
 * partial reads and transient errors.
 *
 * Returns 0 once @to_read bytes are accumulated, -ESHUTDOWN on a hard
 * receive error or when MAX_RECV_RETRY_TIMES transient errors occur.
 * NOTE(review): -EBADMSG retries do not bump retry_time, so a stream of
 * them could loop indefinitely — confirm this is intended (it mirrors
 * TLS record-decrypt hiccups being treated as always-retryable).
 */
static int tcp_read_buffer_from_socket(struct socket *sock, void *buf,
				       unsigned int to_read)
{
	int read_cnt = 0;	/* bytes accumulated so far */
	int retry_time = 0;	/* transient-error retries consumed */
	int rc = 0;
	struct msghdr hmdfs_msg;
	struct kvec iov;

	do {
		/* Aim the next read at the unread tail of the buffer */
		iov.iov_base = (char *)buf + read_cnt;
		iov.iov_len = to_read - read_cnt;
		memset(&hmdfs_msg, 0, sizeof(hmdfs_msg));
		hmdfs_msg.msg_flags = MSG_WAITALL;
		hmdfs_msg.msg_control = NULL;
		hmdfs_msg.msg_controllen = 0;
		rc = recvmsg_nofs(sock, &hmdfs_msg, &iov, 1,
				  to_read - read_cnt, hmdfs_msg.msg_flags);
		/* Retry without consuming the retry budget */
		if (rc == -EBADMSG) {
			usleep_range(1000, 2000);
			continue;
		}
		/* Counted transient errors: back off briefly and retry */
		if (rc == -EAGAIN || rc == -ETIMEDOUT || rc == -EINTR) {
			retry_time++;
			hmdfs_info("read again %d", rc);
			usleep_range(1000, 2000);
			continue;
		}
		// error occurred
		if (rc <= 0) {
			hmdfs_err("tcp recv error %d", rc);
			return -ESHUTDOWN;
		}
		read_cnt += rc;
		if (read_cnt != to_read)
			hmdfs_info("read again %d/%d", read_cnt, to_read);
	} while (read_cnt < to_read && retry_time < MAX_RECV_RETRY_TIMES);
	if (read_cnt == to_read)
		return 0;
	/* Retry budget exhausted before the full body arrived */
	return -ESHUTDOWN;
}
164
165static int hmdfs_drop_readpage_buffer(struct socket *sock,
166				      struct hmdfs_head_cmd *recv)
167{
168	unsigned int len;
169	void *buf = NULL;
170	int err;
171
172	len = le32_to_cpu(recv->data_len) - sizeof(struct hmdfs_head_cmd);
173	if (len > HMDFS_PAGE_SIZE || !len) {
174		hmdfs_err("recv invalid readpage length %u", len);
175		return -EINVAL;
176	}
177
178	/* Abort the connection if no memory */
179	buf = kmalloc(len, GFP_KERNEL);
180	if (!buf)
181		return -ESHUTDOWN;
182
183	err = tcp_read_buffer_from_socket(sock, buf, len);
184	kfree(buf);
185
186	return err;
187}
188
189static int hmdfs_get_readpage_buffer(struct socket *sock,
190				     struct hmdfs_head_cmd *recv,
191				     struct page *page)
192{
193	char *page_buf = NULL;
194	unsigned int out_len;
195	int err;
196
197	out_len = le32_to_cpu(recv->data_len) - sizeof(struct hmdfs_head_cmd);
198	if (out_len > HMDFS_PAGE_SIZE || !out_len) {
199		hmdfs_err("recv invalid readpage length %u", out_len);
200		return -EINVAL;
201	}
202
203	page_buf = kmap(page);
204	err = tcp_read_buffer_from_socket(sock, page_buf, out_len);
205	if (err)
206		goto out_unmap;
207	if (out_len != HMDFS_PAGE_SIZE)
208		memset(page_buf + out_len, 0, HMDFS_PAGE_SIZE - out_len);
209
210out_unmap:
211	kunmap(page);
212	return err;
213}
214
/*
 * Handle a F_READPAGE response on a WORKING (TLS) connection.
 *
 * The message head @recv has already been consumed; only the page
 * payload (if any) remains on the socket.  On the happy path the
 * payload is read into the waiting request's page and the request is
 * completed.  If the request cannot be matched, or its timeout work has
 * already started running, the payload is drained and discarded so the
 * byte stream stays parseable.
 */
static int tcp_recvpage_tls(struct connection *connect,
			    struct hmdfs_head_cmd *recv)
{
	int ret = 0;
	struct tcp_handle *tcp = NULL;
	struct hmdfs_peer *node = NULL;
	struct page *page = NULL;
	struct hmdfs_async_work *async_work = NULL;
	int rd_err;

	if (!connect) {
		hmdfs_err("tcp connect == NULL");
		return -ESHUTDOWN;
	}
	node = connect->node;
	tcp = (struct tcp_handle *)(connect->connect_handle);

	/* Peer-side read error, if any, travels in ret_code */
	rd_err = le32_to_cpu(recv->ret_code);
	if (rd_err)
		hmdfs_warning("tcp: readpage from peer %llu ret err %d",
			      node->device_id, rd_err);

	/* Match the response to its pending request by msg_id */
	async_work = (struct hmdfs_async_work *)hmdfs_find_msg_head(node,
						le32_to_cpu(recv->msg_id), recv->operations);
	/*
	 * cancel_delayed_work() failing means the timeout handler already
	 * owns (or finished) the work; fall through to the drop path.
	 */
	if (!async_work || !cancel_delayed_work(&async_work->d_work))
		goto out;

	page = async_work->page;
	if (!page) {
		hmdfs_err("page not found");
		goto out;
	}

	/* Only a successful peer response carries page data to read */
	if (!rd_err) {
		ret = hmdfs_get_readpage_buffer(tcp->sock, recv, page);
		if (ret)
			rd_err = ret;
	}
	hmdfs_client_recv_readpage(recv, rd_err, async_work);
	asw_put(async_work);
	return ret;

out:
	/* async_work will be released by recvpage in normal processure */
	if (async_work)
		asw_put(async_work);
	hmdfs_err_ratelimited("timeout and droppage");
	hmdfs_client_resp_statis(node->sbi, F_READPAGE, HMDFS_RESP_DELAY, 0, 0);
	/* Drain the unread payload so the next head parses correctly */
	if (!rd_err)
		ret = hmdfs_drop_readpage_buffer(tcp->sock, recv);
	return ret;
}
267
268static void aeadcipher_cb(void *req, int error)
269{
270	struct aeadcrypt_result *result = ((struct crypto_async_request *)req)->data;
271
272	if (error == -EINPROGRESS)
273		return;
274	result->err = error;
275	complete(&result->completion);
276}
277
/*
 * Run the AEAD operation on @req: encrypt when @flag is non-zero,
 * decrypt otherwise.  Synchronous completions return immediately; an
 * async -EINPROGRESS/-EBUSY answer waits for the callback.
 *
 * NOTE(review): @result is passed BY VALUE, so the completion waited on
 * below is a copy — while aeadcipher_cb() completes the caller's
 * original (registered via set_aeadcipher()).  The async wait therefore
 * looks like it can never be woken.  This is harmless only if the
 * software gcm(aes) transform here always completes synchronously —
 * confirm, or change the parameter to a pointer.
 */
static int aeadcipher_en_de(struct aead_request *req,
			    struct aeadcrypt_result result, int flag)
{
	int rc = 0;

	if (flag)
		rc = crypto_aead_encrypt(req);
	else
		rc = crypto_aead_decrypt(req);
	switch (rc) {
	case 0:
		/* completed synchronously */
		break;
	case -EINPROGRESS:
	case -EBUSY:
		/* async path: wait for aeadcipher_cb() to fire */
		rc = wait_for_completion_interruptible(&result.completion);
		if (!rc && !result.err)
			reinit_completion(&result.completion);
		break;
	default:
		hmdfs_err("returned rc %d result %d", rc, result.err);
		break;
	}
	return rc;
}
302
303static int set_aeadcipher(struct crypto_aead *tfm, struct aead_request *req,
304			  struct aeadcrypt_result *result)
305{
306	init_completion(&result->completion);
307	aead_request_set_callback(
308		req, CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP,
309		aeadcipher_cb, result);
310	return 0;
311}
312
/*
 * Encrypt @src_len bytes from @src_buf into @dst_buf using @con's AEAD
 * transform.  Output layout: random IV (HMDFS_IV_SIZE) followed by
 * ciphertext + auth tag.  Both buffers must be directly-mapped kernel
 * memory (scatterlists are built over them).
 * NOTE(review): assumes dst_len >= src_len + HMDFS_IV_SIZE +
 * HMDFS_TAG_SIZE — callers are responsible; not re-checked here.
 *
 * Returns 0 on success or a negative errno.
 */
int aeadcipher_encrypt_buffer(struct connection *con, __u8 *src_buf,
			      size_t src_len, __u8 *dst_buf, size_t dst_len)
{
	int ret = 0;
	struct scatterlist src, dst;
	struct aead_request *req = NULL;
	struct aeadcrypt_result result;
	__u8 cipher_iv[HMDFS_IV_SIZE];

	if (src_len <= 0)
		return -EINVAL;
	if (!virt_addr_valid(src_buf) || !virt_addr_valid(dst_buf)) {
		WARN_ON(1);
		hmdfs_err("encrypt address is invalid");
		return -EPERM;
	}

	/* Fresh random IV per message; a copy leads the output buffer */
	get_random_bytes(cipher_iv, HMDFS_IV_SIZE);
	memcpy(dst_buf, cipher_iv, HMDFS_IV_SIZE);
	req = aead_request_alloc(con->tfm, GFP_KERNEL);
	if (!req) {
		hmdfs_err("aead_request_alloc() failed");
		return -ENOMEM;
	}
	ret = set_aeadcipher(con->tfm, req, &result);
	if (ret) {
		hmdfs_err("set_enaeadcipher exit fault");
		goto out;
	}

	/* Ciphertext + tag land just past the IV in dst_buf */
	sg_init_one(&src, src_buf, src_len);
	sg_init_one(&dst, dst_buf + HMDFS_IV_SIZE, dst_len - HMDFS_IV_SIZE);
	aead_request_set_crypt(req, &src, &dst, src_len, cipher_iv);
	aead_request_set_ad(req, 0);
	ret = aeadcipher_en_de(req, result, ENCRYPT_FLAG);
out:
	aead_request_free(req);
	return ret;
}
352
/*
 * Decrypt @src_len bytes from @src_buf (layout: IV || ciphertext || tag,
 * as produced by aeadcipher_encrypt_buffer()) into @dst_buf using
 * @con's AEAD transform.  Both buffers must be directly-mapped kernel
 * memory.
 *
 * Returns 0 on success or a negative errno (including authentication
 * failure from the AEAD layer).
 */
int aeadcipher_decrypt_buffer(struct connection *con, __u8 *src_buf,
			      size_t src_len, __u8 *dst_buf, size_t dst_len)
{
	int ret = 0;
	struct scatterlist src, dst;
	struct aead_request *req = NULL;
	struct aeadcrypt_result result;
	__u8 cipher_iv[HMDFS_IV_SIZE];

	/* Must hold at least the IV, the tag, and one payload byte */
	if (src_len <= HMDFS_IV_SIZE + HMDFS_TAG_SIZE)
		return -EINVAL;
	if (!virt_addr_valid(src_buf) || !virt_addr_valid(dst_buf)) {
		WARN_ON(1);
		hmdfs_err("decrypt address is invalid");
		return -EPERM;
	}

	/* The sender prepended the IV to the ciphertext */
	memcpy(cipher_iv, src_buf, HMDFS_IV_SIZE);
	req = aead_request_alloc(con->tfm, GFP_KERNEL);
	if (!req) {
		hmdfs_err("aead_request_alloc() failed");
		return -ENOMEM;
	}
	ret = set_aeadcipher(con->tfm, req, &result);
	if (ret) {
		hmdfs_err("set_deaeadcipher exit fault");
		goto out;
	}

	/* Decrypt ciphertext + tag that follow the IV */
	sg_init_one(&src, src_buf + HMDFS_IV_SIZE, src_len - HMDFS_IV_SIZE);
	sg_init_one(&dst, dst_buf, dst_len);
	aead_request_set_crypt(req, &src, &dst, src_len - HMDFS_IV_SIZE,
			       cipher_iv);
	aead_request_set_ad(req, 0);
	ret = aeadcipher_en_de(req, result, DECRYPT_FLAG);
out:
	aead_request_free(req);
	return ret;
}
392
/*
 * Receive and decrypt one message body during the handshake (pre-TLS)
 * phase.  The on-wire body is IV || ciphertext || tag; the decrypted
 * plaintext is handed to the state callback.
 *
 * Ownership: on the callback path @outdata is passed to (and presumably
 * freed by) the callback — TODO confirm against the handler
 * implementations; when no callback exists it is freed here.  A
 * head-only message (recv_len == sizeof(head)) invokes the callback with
 * outdata == NULL and outlen == 0.
 */
static int tcp_recvbuffer_cipher(struct connection *connect,
				 struct hmdfs_head_cmd *recv)
{
	int ret = 0;
	struct tcp_handle *tcp = NULL;
	size_t cipherbuffer_len;
	__u8 *cipherbuffer = NULL;
	size_t outlen = 0;
	__u8 *outdata = NULL;
	__u32 recv_len = le32_to_cpu(recv->data_len);

	tcp = (struct tcp_handle *)(connect->connect_handle);
	/* Head-only message: nothing on the wire beyond the head */
	if (recv_len == sizeof(struct hmdfs_head_cmd))
		goto out_recv_head;
	else if (recv_len > sizeof(struct hmdfs_head_cmd) &&
	    recv_len <= ADAPTER_MESSAGE_LENGTH)
		/* Body plus the IV/tag overhead added by the sender */
		cipherbuffer_len = recv_len - sizeof(struct hmdfs_head_cmd) +
				   HMDFS_IV_SIZE + HMDFS_TAG_SIZE;
	else
		return -ENOMSG;
	cipherbuffer = kzalloc(cipherbuffer_len, GFP_KERNEL);
	if (!cipherbuffer) {
		hmdfs_err("zalloc cipherbuffer error");
		return -ESHUTDOWN;
	}
	outlen = cipherbuffer_len - HMDFS_IV_SIZE - HMDFS_TAG_SIZE;
	outdata = kzalloc(outlen, GFP_KERNEL);
	if (!outdata) {
		hmdfs_err("encrypt zalloc outdata error");
		kfree(cipherbuffer);
		return -ESHUTDOWN;
	}

	ret = tcp_read_buffer_from_socket(tcp->sock, cipherbuffer,
					  cipherbuffer_len);
	if (ret)
		goto out_recv;
	ret = aeadcipher_decrypt_buffer(connect, cipherbuffer, cipherbuffer_len,
					outdata, outlen);
	if (ret) {
		hmdfs_err("decrypt_buf fail");
		goto out_recv;
	}
out_recv_head:
	/* Dispatch plaintext by connection state; callback owns outdata */
	if (connect_recv_callback[connect->status]) {
		connect_recv_callback[connect->status](connect, recv, outdata,
						       outlen);
	} else {
		kfree(outdata);
		hmdfs_err("encypt callback NULL status %d", connect->status);
	}
	kfree(cipherbuffer);
	return ret;
out_recv:
	kfree(cipherbuffer);
	kfree(outdata);
	return ret;
}
451
/*
 * Receive one message body on a WORKING connection (kernel TLS already
 * decrypts the stream, so the body is read as plaintext) and dispatch
 * it to the state callback.
 *
 * Ownership: on the callback path @outdata is passed to (and presumably
 * freed by) the callback — TODO confirm against the handler
 * implementations; when no callback exists it is freed here.  Head-only
 * messages invoke the callback with outdata == NULL and outlen == 0.
 * Always returns 0 except on read failure.
 */
static int tcp_recvbuffer_tls(struct connection *connect,
			      struct hmdfs_head_cmd *recv)
{
	int ret = 0;
	struct tcp_handle *tcp = NULL;
	size_t outlen;
	__u8 *outdata = NULL;
	__u32 recv_len = le32_to_cpu(recv->data_len);

	tcp = (struct tcp_handle *)(connect->connect_handle);
	outlen = recv_len - sizeof(struct hmdfs_head_cmd);
	if (outlen == 0)
		goto out_recv_head;

	/*
	 * NOTE: Up to half of the allocated memory may be wasted due to
	 * the Internal Fragmentation, however the memory allocation times
	 * can be reduced and we don't have to adjust existing message
	 * transporting mechanism
	 */
	outdata = kmalloc(outlen, GFP_KERNEL);
	if (!outdata)
		return -ESHUTDOWN;

	ret = tcp_read_buffer_from_socket(tcp->sock, outdata, outlen);
	if (ret) {
		kfree(outdata);
		return ret;
	}
	tcp->connect->stat.recv_bytes += outlen;
out_recv_head:
	/* Dispatch by connection state; callback owns outdata */
	if (connect_recv_callback[connect->status]) {
		connect_recv_callback[connect->status](connect, recv, outdata,
						       outlen);
	} else {
		kfree(outdata);
		hmdfs_err("callback NULL status %d", connect->status);
	}
	return 0;
}
492
/*
 * Receive and dispatch exactly one message from @tcp's socket.
 *
 * Reads the fixed-size head, validates magic/version and the declared
 * length, then routes the body: F_READPAGE responses on a WORKING
 * connection go to the zero-copy page path, other WORKING traffic to
 * the TLS plaintext path, and handshake traffic to the cipher path.
 * Invalid heads are logged and silently dropped (returns 0) — the
 * stream is assumed to still be in sync since only the head was read.
 *
 * Returns 0 or a negative errno; -ESHUTDOWN tells the receive thread
 * to tear the connection down.
 */
static int tcp_receive_from_sock(struct tcp_handle *tcp)
{
	struct hmdfs_head_cmd *recv = NULL;
	int ret = 0;

	if (!tcp) {
		hmdfs_info("tcp recv thread !tcp");
		return -ESHUTDOWN;
	}

	if (!tcp->sock) {
		hmdfs_info("tcp recv thread !sock");
		return -ESHUTDOWN;
	}

	recv = kmem_cache_alloc(tcp->recv_cache, GFP_KERNEL);
	if (!recv) {
		hmdfs_info("tcp recv thread !cache");
		return -ESHUTDOWN;
	}

	ret = tcp_read_head_from_socket(tcp->sock, recv,
					sizeof(struct hmdfs_head_cmd));
	if (ret)
		goto out;

	tcp->connect->stat.recv_bytes += sizeof(struct hmdfs_head_cmd);
	tcp->connect->stat.recv_message_count++;

	/* Reject frames from incompatible or corrupted peers */
	if (recv->magic != HMDFS_MSG_MAGIC || recv->version != HMDFS_VERSION) {
		hmdfs_info_ratelimited("tcp recv fd %d wrong magic. drop message",
				       tcp->fd);
		goto out;
	}

	/* data_len covers the head itself, so it must be in [head, max] */
	if ((le32_to_cpu(recv->data_len) >
	    HMDFS_MAX_MESSAGE_LEN + sizeof(struct hmdfs_head_cmd)) ||
	    (le32_to_cpu(recv->data_len) < sizeof(struct hmdfs_head_cmd))) {
		hmdfs_info("tcp recv fd %d length error. drop message",
			   tcp->fd);
		goto out;
	}

	/* Readpage responses take the direct-to-page fast path */
	if (tcp->connect->status == CONNECT_STAT_WORKING &&
	    recv->operations.command == F_READPAGE &&
	    recv->operations.cmd_flag == C_RESPONSE) {
		ret = tcp_recvpage_tls(tcp->connect, recv);
		goto out;
	}

	if (tcp->connect->status == CONNECT_STAT_WORKING)
		ret = tcp_recvbuffer_tls(tcp->connect, recv);
	else
		ret = tcp_recvbuffer_cipher(tcp->connect, recv);

out:
	kmem_cache_free(tcp->recv_cache, recv);
	return ret;
}
552
/*
 * Check whether @tcp still refers to a usable, established connection.
 *
 * Verifies the handle/socket pointers, the TCP state, and the socket
 * layer state; with kTLS enabled it additionally checks that the RX
 * strparser has not stopped (a stopped strparser means the TLS receive
 * path is dead even though TCP may still look established).
 */
static bool tcp_handle_is_available(struct tcp_handle *tcp)
{
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	struct tls_context *tls_ctx = NULL;
	struct tls_sw_context_rx *ctx = NULL;

#endif
	if (!tcp || !tcp->sock || !tcp->sock->sk) {
		hmdfs_err("Invalid tcp connection");
		return false;
	}

	if (tcp->sock->sk->sk_state != TCP_ESTABLISHED) {
		hmdfs_err("TCP conn %d is broken, current sk_state is %d",
			  tcp->fd, tcp->sock->sk->sk_state);
		return false;
	}

	if (tcp->sock->state != SS_CONNECTING &&
	    tcp->sock->state != SS_CONNECTED) {
		hmdfs_err("TCP conn %d is broken, current sock state is %d",
			  tcp->fd, tcp->sock->state);
		return false;
	}

#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	/* kTLS RX path dies when its strparser stops */
	tls_ctx = tls_get_ctx(tcp->sock->sk);
	if (tls_ctx) {
		ctx = tls_sw_ctx_rx(tls_ctx);
		if (ctx && ctx->strp.stopped) {
			hmdfs_err(
				"TCP conn %d is broken, the strparser has stopped",
				tcp->fd);
			return false;
		}
	}
#endif
	return true;
}
592
/*
 * Dedicated receive kthread for one TCP connection.
 *
 * Loops receiving messages until the connection breaks (-ESHUTDOWN) or
 * kthread_stop() is called.  Runs with the superblock's system creds so
 * message handlers have the right identity.  close_mutex serializes
 * receiving against tcp_close_socket(); on -ESHUTDOWN the thread clears
 * recv_task itself (still under close_mutex) so nobody tries to
 * kthread_stop() an exiting thread, then triggers reconnection.
 * The connection reference taken at thread creation is dropped on exit.
 */
static int tcp_recv_thread(void *arg)
{
	int ret = 0;
	struct tcp_handle *tcp = (struct tcp_handle *)arg;
	const struct cred *old_cred;

	WARN_ON(!tcp);
	WARN_ON(!tcp->sock);
	set_freezable();

	old_cred = hmdfs_override_creds(tcp->connect->node->sbi->system_cred);

	while (!kthread_should_stop()) {
		/*
		 * 1. In case the redundant connection has not been mounted on
		 *    a peer
		 * 2. Lock is unnecessary since a transient state is acceptable
		 */
		if (tcp_handle_is_available(tcp) &&
		    list_empty(&tcp->connect->list))
			goto freeze;
		if (!mutex_trylock(&tcp->close_mutex))
			continue;
		if (tcp_handle_is_available(tcp))
			ret = tcp_receive_from_sock(tcp);
		else
			ret = -ESHUTDOWN;
		/*
		 * This kthread will exit if ret is -ESHUTDOWN, thus we need to
		 * set recv_task to NULL to avoid calling kthread_stop() from
		 * tcp_close_socket().
		 */
		if (ret == -ESHUTDOWN)
			tcp->recv_task = NULL;
		mutex_unlock(&tcp->close_mutex);
		if (ret == -ESHUTDOWN) {
			/* Mark the connection dead and kick off reconnect */
			hmdfs_node_inc_evt_seq(tcp->connect->node);
			tcp->connect->status = CONNECT_STAT_STOP;
			if (tcp->connect->node->status != NODE_STAT_OFFLINE)
				hmdfs_reget_connection(tcp->connect);
			break;
		}
freeze:
		schedule();
		try_to_freeze();
	}

	hmdfs_info("Exiting. Now, sock state = %d", tcp->sock->state);
	hmdfs_revert_creds(old_cred);
	connection_put(tcp->connect);
	return 0;
}
645
/*
 * Send one message during the handshake (pre-TLS) phase.
 *
 * The head is sent in the clear; the body (if any) is AEAD-encrypted
 * into a temporary buffer (IV || ciphertext || tag) and sent as a
 * second kvec.  send_mutex keeps concurrent senders from interleaving
 * their kvecs on the socket.
 *
 * Returns 0 on full send, -EAGAIN on a short send, -ESHUTDOWN on a
 * socket error, or a negative errno from allocation/encryption.
 */
static int tcp_send_message_sock_cipher(struct tcp_handle *tcp,
					struct hmdfs_send_data *msg)
{
	int ret = 0;
	__u8 *outdata = NULL;
	size_t outlen = 0;	/* stays 0 for head-only messages */
	int send_len = 0;
	int send_vec_cnt = 0;
	struct msghdr tcp_msg;
	struct kvec iov[TCP_KVEC_ELE_DOUBLE];

	memset(&tcp_msg, 0, sizeof(tcp_msg));
	if (!tcp || !tcp->sock) {
		hmdfs_err("encrypt tcp socket = NULL");
		return -ESHUTDOWN;
	}
	iov[0].iov_base = msg->head;
	iov[0].iov_len = msg->head_len;
	send_vec_cnt = TCP_KVEC_HEAD;
	if (msg->len == 0)
		goto send;

	/* Encrypted body carries IV + tag overhead */
	outlen = msg->len + HMDFS_IV_SIZE + HMDFS_TAG_SIZE;
	outdata = kzalloc(outlen, GFP_KERNEL);
	if (!outdata) {
		hmdfs_err("tcp send message encrypt fail to alloc outdata");
		return -ENOMEM;
	}
	ret = aeadcipher_encrypt_buffer(tcp->connect, msg->data, msg->len,
					outdata, outlen);
	if (ret) {
		hmdfs_err("encrypt_buf fail");
		goto out;
	}
	iov[1].iov_base = outdata;
	iov[1].iov_len = outlen;
	send_vec_cnt = TCP_KVEC_ELE_DOUBLE;
send:
	mutex_lock(&tcp->send_mutex);
	send_len = sendmsg_nofs(tcp->sock, &tcp_msg, iov, send_vec_cnt,
				msg->head_len + outlen);
	mutex_unlock(&tcp->send_mutex);
	if (send_len <= 0) {
		hmdfs_err("error %d", send_len);
		ret = -ESHUTDOWN;
	} else if (send_len != msg->head_len + outlen) {
		hmdfs_err("send part of message. %d/%zu", send_len,
			  msg->head_len + outlen);
		ret = -EAGAIN;
	} else {
		ret = 0;
	}
out:
	kfree(outdata);
	return ret;
}
702
703static int tcp_send_message_sock_tls(struct tcp_handle *tcp,
704				     struct hmdfs_send_data *msg)
705{
706	int send_len = 0;
707	int send_vec_cnt = 0;
708	struct msghdr tcp_msg;
709	struct kvec iov[TCP_KVEC_ELE_TRIPLE];
710
711	memset(&tcp_msg, 0, sizeof(tcp_msg));
712	if (!tcp || !tcp->sock) {
713		hmdfs_err("tcp socket = NULL");
714		return -ESHUTDOWN;
715	}
716	iov[TCP_KVEC_HEAD].iov_base = msg->head;
717	iov[TCP_KVEC_HEAD].iov_len = msg->head_len;
718	if (msg->len == 0 && msg->sdesc_len == 0) {
719		send_vec_cnt = TCP_KVEC_ELE_SINGLE;
720	} else if (msg->sdesc_len == 0) {
721		iov[TCP_KVEC_DATA].iov_base = msg->data;
722		iov[TCP_KVEC_DATA].iov_len = msg->len;
723		send_vec_cnt = TCP_KVEC_ELE_DOUBLE;
724	} else {
725		iov[TCP_KVEC_FILE_PARA].iov_base = msg->sdesc;
726		iov[TCP_KVEC_FILE_PARA].iov_len = msg->sdesc_len;
727		iov[TCP_KVEC_FILE_CONTENT].iov_base = msg->data;
728		iov[TCP_KVEC_FILE_CONTENT].iov_len = msg->len;
729		send_vec_cnt = TCP_KVEC_ELE_TRIPLE;
730	}
731	mutex_lock(&tcp->send_mutex);
732	send_len = sendmsg_nofs(tcp->sock, &tcp_msg, iov, send_vec_cnt,
733				msg->head_len + msg->len + msg->sdesc_len);
734	mutex_unlock(&tcp->send_mutex);
735	if (send_len == -EBADMSG) {
736		return -EBADMSG;
737	} else if (send_len <= 0) {
738		hmdfs_err("error %d", send_len);
739		return -ESHUTDOWN;
740	} else if (send_len != msg->head_len + msg->len + msg->sdesc_len) {
741		hmdfs_err("send part of message. %d/%zu", send_len,
742			  msg->head_len + msg->len);
743		tcp->connect->stat.send_bytes += send_len;
744		return -EAGAIN;
745	}
746	tcp->connect->stat.send_bytes += send_len;
747	tcp->connect->stat.send_message_count++;
748	return 0;
749}
750
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
/*
 * Build and send an F_CONNECT_REKEY request on @connect, asking the
 * peer to participate in a key refresh (update_request is
 * UPDATE_NOT_REQUESTED, i.e. this side initiates).
 *
 * Returns 0 on success or the tcp_send_message_sock_tls() error.
 */
int tcp_send_rekey_request(struct connection *connect)
{
	struct tcp_handle *tcp = connect->connect_handle;
	struct connection_rekey_request *rekey = NULL;
	struct hmdfs_head_cmd *head = NULL;
	struct hmdfs_send_data msg;
	struct hmdfs_cmd operations;
	int ret;

	hmdfs_init_cmd(&operations, F_CONNECT_REKEY);

	/* Head and rekey parameters share one allocation */
	head = kzalloc(sizeof(*head) + sizeof(*rekey), GFP_KERNEL);
	if (!head)
		return -ENOMEM;
	rekey = (struct connection_rekey_request *)(head + 1);
	rekey->update_request = cpu_to_le32(UPDATE_NOT_REQUESTED);

	head->magic = HMDFS_MSG_MAGIC;
	head->version = HMDFS_VERSION;
	head->operations = operations;
	head->data_len = cpu_to_le32(sizeof(*head) + sizeof(*rekey));
	head->reserved = 0;
	head->reserved1 = 0;
	head->ret_code = 0;

	msg.head = head;
	msg.head_len = sizeof(*head);
	msg.data = rekey;
	msg.len = sizeof(*rekey);
	msg.sdesc = NULL;
	msg.sdesc_len = 0;

	ret = tcp_send_message_sock_tls(tcp, &msg);
	if (ret != 0)
		hmdfs_err("return error %d", ret);
	kfree(head);
	return ret;
}
#endif
795
/*
 * Top-level send entry for a TCP connection (installed as
 * connection->send_message).
 *
 * Validates @msg, picks the TLS path for WORKING connections or the
 * handshake cipher path otherwise, and — with encryption enabled —
 * opportunistically initiates a rekey once REKEY_LIFETIME jiffies have
 * elapsed since the last one.
 *
 * Returns 0 on success, -EAGAIN while the connection is stopped, or a
 * negative errno from the underlying send.
 */
static int tcp_send_message(struct connection *connect,
			    struct hmdfs_send_data *msg)
{
	int ret = 0;
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	unsigned long nowtime = jiffies;
#endif
	struct tcp_handle *tcp = NULL;

	if (!connect) {
		hmdfs_err("tcp connection = NULL ");
		return -ESHUTDOWN;
	}
	if (!msg) {
		hmdfs_err("msg = NULL");
		return -EINVAL;
	}
	if (msg->len > HMDFS_MAX_MESSAGE_LEN) {
		hmdfs_err("message->len error: %zu", msg->len);
		return -EINVAL;
	}
	tcp = (struct tcp_handle *)(connect->connect_handle);
	if (connect->status == CONNECT_STAT_STOP)
		return -EAGAIN;

	trace_hmdfs_tcp_send_message(msg->head);

	if (connect->status == CONNECT_STAT_WORKING)
		ret = tcp_send_message_sock_tls(tcp, msg);
	else
		ret = tcp_send_message_sock_cipher(tcp, msg);

	if (ret != 0) {
		hmdfs_err("return error %d", ret);
		return ret;
	}
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	/* Unsigned subtraction is wrap-safe for this jiffies age check */
	if (nowtime - connect->stat.rekey_time >= REKEY_LIFETIME &&
	    connect->status == CONNECT_STAT_WORKING) {
		hmdfs_info("send rekey message to devid %llu",
			   connect->node->device_id);
		ret = tcp_send_rekey_request(connect);
		if (ret == 0)
			set_crypto_info(connect, SET_CRYPTO_SEND);
		connect->stat.rekey_time = nowtime;
	}
#endif
	return ret;
}
845
/*
 * Stop @tcp's receive kthread, if one is still registered.
 *
 * Serialized against the receive loop via close_mutex (the thread
 * clears recv_task itself when exiting on -ESHUTDOWN, so we never stop
 * a thread that is already going away).  If kthread_stop() returns
 * -EINTR the thread was killed before it ever scheduled, meaning it
 * never dropped the connection reference handed to it at creation —
 * drop it here instead.
 */
void tcp_close_socket(struct tcp_handle *tcp)
{
	int ret;
	if (!tcp)
		return;
	mutex_lock(&tcp->close_mutex);
	if (tcp->recv_task) {
		ret = kthread_stop(tcp->recv_task);
		/* recv_task killed before sched, we need to put the connect */
		if (ret == -EINTR)
			connection_put(tcp->connect);
		tcp->recv_task = NULL;
	}
	mutex_unlock(&tcp->close_mutex);
}
861
862static int set_tfm(__u8 *master_key, struct crypto_aead *tfm)
863{
864	int ret = 0;
865	int iv_len;
866	__u8 *sec_key = NULL;
867
868	sec_key = master_key;
869	crypto_aead_clear_flags(tfm, ~0);
870	ret = crypto_aead_setkey(tfm, sec_key, HMDFS_KEY_SIZE);
871	if (ret) {
872		hmdfs_err("failed to set the key");
873		goto out;
874	}
875	ret = crypto_aead_setauthsize(tfm, HMDFS_TAG_SIZE);
876	if (ret) {
877		hmdfs_err("authsize length is error");
878		goto out;
879	}
880
881	iv_len = crypto_aead_ivsize(tfm);
882	if (iv_len != HMDFS_IV_SIZE) {
883		hmdfs_err("IV recommended value should be set %d", iv_len);
884		ret = -ENODATA;
885	}
886out:
887	return ret;
888}
889
890static bool is_tcp_socket(struct tcp_handle *tcp)
891{
892	struct inet_connection_sock *icsk;
893
894	if (!tcp || !tcp->sock || !tcp->sock->sk) {
895		hmdfs_err("invalid tcp handle");
896		return false;
897	}
898
899	lock_sock(tcp->sock->sk);
900	if (tcp->sock->sk->sk_protocol != IPPROTO_TCP ||
901	    tcp->sock->type != SOCK_STREAM ||
902	    tcp->sock->sk->sk_family != AF_INET) {
903		hmdfs_err("invalid socket protocol");
904		release_sock(tcp->sock->sk);
905		return false;
906	}
907
908	icsk = inet_csk(tcp->sock->sk);
909	if (icsk->icsk_ulp_ops) {
910		hmdfs_err("ulp not NULL");
911		release_sock(tcp->sock->sk);
912		return false;
913	}
914
915	release_sock(tcp->sock->sk);
916	return true;
917}
918
/*
 * Bind @socket/@fd to @tcp and bring the connection to a runnable
 * state: validate the socket, create the receive slab cache, configure
 * timeouts, install @master_key into the AEAD transform, and spawn the
 * receive kthread (created stopped; the caller wakes it).
 *
 * The kthread holds its own connection reference.  On any failure the
 * steps completed so far are unwound in reverse via the goto chain and
 * a negative errno is returned.  The socket reference itself is owned
 * by the caller (put_sock only clears our pointers).
 */
static int tcp_update_socket(struct tcp_handle *tcp, int fd,
			     uint8_t *master_key, struct socket *socket)
{
	int err = 0;
	struct hmdfs_peer *node = NULL;

	if (!master_key || fd == 0)
		return -EAGAIN;

	tcp->sock = socket;
	tcp->fd = fd;

	if (!is_tcp_socket(tcp)) {
		err = -EINVAL;
		goto put_sock;
	}

	if (!tcp_handle_is_available(tcp)) {
		err = -EPIPE;
		goto put_sock;
	}

	hmdfs_info("socket fd %d, state %d, refcount %ld protocol %d", fd,
		   socket->state, file_count(socket->file),
		   socket->sk->sk_protocol);

	/* Slab cache for per-message receive heads */
	tcp->recv_cache = kmem_cache_create("hmdfs_socket",
					    tcp->recvbuf_maxsize,
					    0, SLAB_HWCACHE_ALIGN, NULL);
	if (!tcp->recv_cache) {
		err = -ENOMEM;
		goto put_sock;
	}

	err = tcp_set_recvtimeo(socket, TCP_RECV_TIMEOUT);
	if (err) {
		hmdfs_err("tcp set timeout error");
		goto free_mem_cache;
	}

	/* send key and recv key, default MASTER KEY */
	memcpy(tcp->connect->master_key, master_key, HMDFS_KEY_SIZE);
	memcpy(tcp->connect->send_key, master_key, HMDFS_KEY_SIZE);
	memcpy(tcp->connect->recv_key, master_key, HMDFS_KEY_SIZE);
	tcp->connect->tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
	if (IS_ERR(tcp->connect->tfm)) {
		err = PTR_ERR(tcp->connect->tfm);
		tcp->connect->tfm = NULL;
		hmdfs_err("failed to load transform for gcm(aes):%d", err);
		goto free_mem_cache;
	}

	err = set_tfm(master_key, tcp->connect->tfm);
	if (err) {
		hmdfs_err("tfm seting exit fault");
		goto free_crypto;
	}

	/* Reference owned by the receive kthread; dropped on its exit */
	connection_get(tcp->connect);

	node = tcp->connect->node;
	tcp->recv_task = kthread_create(tcp_recv_thread, (void *)tcp,
					"dfs_rcv%u_%llu_%d",
					node->owner, node->device_id, fd);
	if (IS_ERR(tcp->recv_task)) {
		err = PTR_ERR(tcp->recv_task);
		hmdfs_err("tcp->rcev_task %d", err);
		goto put_conn;
	}

	return 0;

put_conn:
	tcp->recv_task = NULL;
	connection_put(tcp->connect);
free_crypto:
	crypto_free_aead(tcp->connect->tfm);
	tcp->connect->tfm = NULL;
free_mem_cache:
	kmem_cache_destroy(tcp->recv_cache);
	tcp->recv_cache = NULL;
put_sock:
	tcp->sock = NULL;
	tcp->fd = 0;

	return err;
}
1006
1007static struct tcp_handle *tcp_alloc_handle(struct connection *connect,
1008	int socket_fd, uint8_t *master_key, struct socket *socket)
1009{
1010	int ret = 0;
1011	struct tcp_handle *tcp = kzalloc(sizeof(*tcp), GFP_KERNEL);
1012
1013	if (!tcp)
1014		return NULL;
1015	tcp->connect = connect;
1016	tcp->connect->connect_handle = (void *)tcp;
1017	tcp->recvbuf_maxsize = MAX_RECV_SIZE;
1018	tcp->recv_task = NULL;
1019	tcp->recv_cache = NULL;
1020	tcp->sock = NULL;
1021	mutex_init(&tcp->close_mutex);
1022	mutex_init(&tcp->send_mutex);
1023	ret = tcp_update_socket(tcp, socket_fd, master_key, socket);
1024	if (ret) {
1025		kfree(tcp);
1026		return NULL;
1027	}
1028	return tcp;
1029}
1030
1031void hmdfs_get_connection(struct hmdfs_peer *peer)
1032{
1033	struct notify_param param;
1034
1035	if (!peer)
1036		return;
1037	param.notify = NOTIFY_GET_SESSION;
1038	param.fd = INVALID_SOCKET_FD;
1039	memcpy(param.remote_cid, peer->cid, HMDFS_CID_SIZE);
1040	notify(peer, &param);
1041}
1042
1043static void connection_notify_to_close(struct connection *conn)
1044{
1045	struct notify_param param;
1046	struct hmdfs_peer *peer = NULL;
1047	struct tcp_handle *tcp = NULL;
1048
1049	tcp = conn->connect_handle;
1050	peer = conn->node;
1051
1052	// libdistbus/src/TcpSession.cpp will close the socket
1053	param.notify = NOTIFY_GET_SESSION;
1054	param.fd = tcp->fd;
1055	memcpy(param.remote_cid, peer->cid, HMDFS_CID_SIZE);
1056	notify(peer, &param);
1057}
1058
/*
 * Tear down @conn and trigger re-establishment.
 *
 * Only the caller that actually moves the connection off
 * conn_impl_list (onto conn_deleting_list) proceeds; concurrent callers
 * bail out, which makes teardown single-shot.  The socket is shut down,
 * the receive thread is stopped unless we ARE the receive thread (it
 * exits on its own after the shutdown), and userspace is asked to close
 * the session.  Finally the list's reference on @conn is dropped.
 */
void hmdfs_reget_connection(struct connection *conn)
{
	struct tcp_handle *tcp = NULL;
	struct connection *conn_impl = NULL;
	struct connection *next = NULL;
	struct task_struct *recv_task = NULL;
	bool should_put = false;
	bool stop_thread = true;

	if (!conn)
		return;

	// One may put a connection if and only if he took it out of the list
	mutex_lock(&conn->node->conn_impl_list_lock);
	list_for_each_entry_safe(conn_impl, next, &conn->node->conn_impl_list,
				  list) {
		if (conn_impl == conn) {
			should_put = true;
			list_move(&conn->list, &conn->node->conn_deleting_list);
			break;
		}
	}
	if (!should_put) {
		/* Someone else already removed it; they own the teardown */
		mutex_unlock(&conn->node->conn_impl_list_lock);
		return;
	}

	tcp = conn->connect_handle;
	if (tcp) {
		recv_task = tcp->recv_task;
		/*
		 * To avoid the receive thread to stop itself. Ensure receive
		 * thread stop before process offline event
		 */
		if (!recv_task || recv_task->pid == current->pid)
			stop_thread = false;
	}
	mutex_unlock(&conn->node->conn_impl_list_lock);

	if (tcp) {
		if (tcp->sock) {
			hmdfs_info("shudown sock: fd = %d, sockref = %ld, connref = %u stop_thread = %d",
				   tcp->fd, file_count(tcp->sock->file),
				   kref_read(&conn->ref_cnt), stop_thread);
			/* Unblocks the receiver so it can observe shutdown */
			kernel_sock_shutdown(tcp->sock, SHUT_RDWR);
		}

		if (stop_thread)
			tcp_close_socket(tcp);

		if (tcp->fd != INVALID_SOCKET_FD)
			connection_notify_to_close(conn);
	}
	/* Drop the reference that conn_impl_list held */
	connection_put(conn);
}
1114
1115static struct connection *
1116lookup_conn_by_socketfd_unsafe(struct hmdfs_peer *node, struct socket *socket)
1117{
1118	struct connection *tcp_conn = NULL;
1119	struct tcp_handle *tcp = NULL;
1120
1121	list_for_each_entry(tcp_conn, &node->conn_impl_list, list) {
1122		if (tcp_conn->connect_handle) {
1123			tcp = (struct tcp_handle *)(tcp_conn->connect_handle);
1124			if (tcp->sock == socket) {
1125				connection_get(tcp_conn);
1126				return tcp_conn;
1127			}
1128		}
1129	}
1130	return NULL;
1131}
1132
1133static void hmdfs_reget_connection_work_fn(struct work_struct *work)
1134{
1135	struct connection *conn =
1136		container_of(work, struct connection, reget_work);
1137
1138	hmdfs_reget_connection(conn);
1139	connection_put(conn);
1140}
1141
/*
 * Allocate a connection object for @node in @status, wired to the TCP
 * transport (close/send_message ops), and build its tcp_handle around
 * @socket/@socket_fd.  The rekey clock starts at "now".
 *
 * Returns the new connection (ref_cnt initialized to 1) or NULL on
 * failure; tcp_alloc_handle() failure frees the connection here since
 * it was never published anywhere.
 */
struct connection *alloc_conn_tcp(struct hmdfs_peer *node, int socket_fd,
				  uint8_t *master_key, uint8_t status, struct socket *socket)
{
	struct connection *tcp_conn = NULL;
	unsigned long nowtime = jiffies;

	tcp_conn = kzalloc(sizeof(*tcp_conn), GFP_KERNEL);
	if (!tcp_conn)
		goto out_err;

	kref_init(&tcp_conn->ref_cnt);
	mutex_init(&tcp_conn->ref_lock);
	INIT_LIST_HEAD(&tcp_conn->list);
	tcp_conn->node = node;
	tcp_conn->close = tcp_stop_connect;
	tcp_conn->send_message = tcp_send_message;
	tcp_conn->type = CONNECT_TYPE_TCP;
	tcp_conn->status = status;
	tcp_conn->stat.rekey_time = nowtime;
	/* NULL on handle-setup failure, detected just below */
	tcp_conn->connect_handle =
		(void *)tcp_alloc_handle(tcp_conn, socket_fd, master_key, socket);
	INIT_WORK(&tcp_conn->reget_work, hmdfs_reget_connection_work_fn);
	if (!tcp_conn->connect_handle) {
		hmdfs_err("Failed to alloc tcp_handle for strcut conn");
		goto out_err;
	}
	return tcp_conn;

out_err:
	kfree(tcp_conn);
	return NULL;
}
1174
1175static struct connection *add_conn_tcp_unsafe(struct hmdfs_peer *node,
1176					      struct socket *socket,
1177					      struct connection *conn2add)
1178{
1179	struct connection *conn;
1180
1181	conn = lookup_conn_by_socketfd_unsafe(node, socket);
1182	if (conn) {
1183		hmdfs_info("socket already in list");
1184		return conn;
1185	}
1186
1187	/* Prefer to use socket opened by local device */
1188	if (conn2add->status == CONNECT_STAT_WAIT_REQUEST)
1189		list_add(&conn2add->list, &node->conn_impl_list);
1190	else
1191		list_add_tail(&conn2add->list, &node->conn_impl_list);
1192	connection_get(conn2add);
1193	return conn2add;
1194}
1195
/*
 * Get (or create) the TCP connection for @node backed by socket @fd.
 *
 * Looks up an existing connection for the socket first; otherwise
 * allocates one and races to publish it on the peer's list.  If we win
 * the race the receive thread is started and, for locally-initiated
 * connections (WAIT_RESPONSE), the handshake request is sent.  If we
 * lose, our freshly-built connection is torn down (fd marked invalid so
 * userspace is not told to close the still-in-use socket) and the
 * winner's connection is returned.
 *
 * Returns a referenced connection, or NULL on failure.
 */
struct connection *hmdfs_get_conn_tcp(struct hmdfs_peer *node, int fd,
				      uint8_t *master_key, uint8_t status)
{
	struct connection *tcp_conn = NULL, *on_peer_conn = NULL;
	struct tcp_handle *tcp = NULL;
	struct socket *socket = NULL;
	int err = 0;

	socket = sockfd_lookup(fd, &err);
	if (!socket) {
		hmdfs_err("lookup socket fail, socket_fd %d, err %d", fd, err);
		return NULL;
	}
	mutex_lock(&node->conn_impl_list_lock);
	tcp_conn = lookup_conn_by_socketfd_unsafe(node, socket);
	mutex_unlock(&node->conn_impl_list_lock);
	if (tcp_conn) {
		hmdfs_info("Got a existing tcp conn: fsocket_fd = %d",
			   fd);
		sockfd_put(socket);
		goto out;
	}

	tcp_conn = alloc_conn_tcp(node, fd, master_key, status, socket);
	if (!tcp_conn) {
		hmdfs_info("Failed to alloc a tcp conn, socket_fd %d", fd);
		sockfd_put(socket);
		goto out;
	}

	/* Publish; another thread may have inserted one meanwhile */
	mutex_lock(&node->conn_impl_list_lock);
	on_peer_conn = add_conn_tcp_unsafe(node, socket, tcp_conn);
	mutex_unlock(&node->conn_impl_list_lock);
	tcp = tcp_conn->connect_handle;
	if (on_peer_conn == tcp_conn) {
		/* We won: start the receiver (created stopped) */
		hmdfs_info("Got a newly allocated tcp conn: socket_fd = %d", fd);
		wake_up_process(tcp->recv_task);
		if (status == CONNECT_STAT_WAIT_RESPONSE)
			connection_send_handshake(
				on_peer_conn, CONNECT_MESG_HANDSHAKE_REQUEST,
				0);
	} else {
		/* We lost: dismantle ours, keep the winner's */
		hmdfs_info("Got a existing tcp conn: socket_fd = %d", fd);
		tcp->fd = INVALID_SOCKET_FD;
		tcp_close_socket(tcp);
		connection_put(tcp_conn);

		tcp_conn = on_peer_conn;
	}

out:
	return tcp_conn;
}
1249
/*
 * connection->close hook for TCP: teardown is handled elsewhere
 * (hmdfs_reget_connection()/tcp_close_socket()), so just log.
 */
void tcp_stop_connect(struct connection *connect)
{
	hmdfs_info("now nothing to do");
}
1254