// SPDX-License-Identifier: GPL-2.0
/*
 * fs/hmdfs/comm/connection.c
 *
 * Copyright (c) 2020-2021 Huawei Device Co., Ltd.
 */

#include "connection.h"

#include <linux/file.h>
#include <linux/freezer.h>
#include <linux/fs.h>
#include <linux/kthread.h>
#include <linux/module.h>
#include <linux/net.h>
#include <linux/tcp.h>
#include <linux/workqueue.h>

#include "device_node.h"
#include "hmdfs.h"
#include "message_verify.h"
#include "node_cb.h"
#include "protocol.h"
#include "socket_adapter.h"

#ifdef CONFIG_HMDFS_FS_ENCRYPTION
#include "crypto.h"
#endif

#define HMDFS_WAIT_REQUEST_END_MIN 20
#define HMDFS_WAIT_REQUEST_END_MAX 30

#define HMDFS_WAIT_CONN_RELEASE (3 * HZ)

#define HMDFS_RETRY_WB_WQ_MAX_ACTIVE 16

static void hs_fill_crypto_data(struct connection *conn_impl, __u8 ops,
				void *data, __u32 len)
{
	struct crypto_body *body = NULL;

	if (len < sizeof(struct crypto_body)) {
		hmdfs_info("crypto body len %u is invalid", len);
		return;
	}
	body = (struct crypto_body *)data;

	/* This is only a test; the proper algorithm needs to be filled in later. */
	body->crypto |= HMDFS_HS_CRYPTO_KTLS_AES128;
	body->crypto = cpu_to_le32(body->crypto);

	hmdfs_info("fill crypto. crypto=0x%08x", body->crypto);
}

static int hs_parse_crypto_data(struct connection *conn_impl, __u8 ops,
				 void *data, __u32 len)
{
	struct crypto_body *hs_crypto = NULL;
	uint32_t crypto;

	if (len < sizeof(struct crypto_body)) {
		hmdfs_info("handshake msg len error, len=%u", len);
		return -1;
	}
	hs_crypto = (struct crypto_body *)data;
	crypto = le32_to_cpu(hs_crypto->crypto);
	conn_impl->crypto = crypto;
	hmdfs_info("ops=%u, len=%u, crypto=0x%08x", ops, len, crypto);
	return 0;
}

static void hs_fill_case_sense_data(struct connection *conn_impl, __u8 ops,
				    void *data, __u32 len)
{
	struct case_sense_body *body = (struct case_sense_body *)data;

	if (len < sizeof(struct case_sense_body)) {
		hmdfs_err("case sensitive len %u is invalid", len);
		return;
	}
	body->case_sensitive = conn_impl->node->sbi->s_case_sensitive;
}

static int hs_parse_case_sense_data(struct connection *conn_impl, __u8 ops,
				     void *data, __u32 len)
{
	struct case_sense_body *body = (struct case_sense_body *)data;
	__u8 sensitive = conn_impl->node->sbi->s_case_sensitive ? 1 : 0;

	if (len < sizeof(struct case_sense_body)) {
		hmdfs_info("case sensitive len %u is invalid", len);
		return -1;
	}
	if (body->case_sensitive != sensitive) {
		hmdfs_err("case sensitive inconsistent, server: %u, client: %u, ops: %u",
			  body->case_sensitive, sensitive, ops);
		return -1;
	}
	return 0;
}

static void hs_fill_feature_data(struct connection *conn_impl, __u8 ops,
				 void *data, __u32 len)
{
	struct feature_body *body = (struct feature_body *)data;

	if (len < sizeof(struct feature_body)) {
		hmdfs_err("feature len %u is invalid", len);
		return;
	}
	body->features = cpu_to_le64(conn_impl->node->sbi->s_features);
	body->reserved = cpu_to_le64(0);
}

static int hs_parse_feature_data(struct connection *conn_impl, __u8 ops,
				 void *data, __u32 len)
{
	struct feature_body *body = (struct feature_body *)data;

	if (len < sizeof(struct feature_body)) {
		hmdfs_err("feature len %u is invalid", len);
		return -1;
	}

	conn_impl->node->features = le64_to_cpu(body->features);
	return 0;
}

/* should ensure len is smaller than 0xffff. */
static const struct conn_hs_extend_reg s_hs_extend_reg[HS_EXTEND_CODE_COUNT] = {
	[HS_EXTEND_CODE_CRYPTO] = {
		.len = sizeof(struct crypto_body),
		.resv = 0,
		.filler = hs_fill_crypto_data,
		.parser = hs_parse_crypto_data
	},
	[HS_EXTEND_CODE_CASE_SENSE] = {
		.len = sizeof(struct case_sense_body),
		.resv = 0,
		.filler = hs_fill_case_sense_data,
		.parser = hs_parse_case_sense_data,
	},
	[HS_EXTEND_CODE_FEATURE_SUPPORT] = {
		.len = sizeof(struct feature_body),
		.resv = 0,
		.filler = hs_fill_feature_data,
		.parser = hs_parse_feature_data,
	},
};

static __u32 hs_get_extend_data_len(void)
{
	__u32 len;
	int i;

	len = sizeof(struct conn_hs_extend_head);

	for (i = 0; i < HS_EXTEND_CODE_COUNT; i++) {
		len += sizeof(struct extend_field_head);
		len += s_hs_extend_reg[i].len;
	}

	hmdfs_info("extend data total len is %u", len);
	return len;
}

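/*
 * Wire layout of the handshake extend data:
 *
 *   struct conn_hs_extend_head (field_cn fields follow)
 *   field_cn repetitions of:
 *     struct extend_field_head { code, len }
 *     'len' bytes of body, filled/parsed via s_hs_extend_reg[code]
 *
 * The parser skips fields whose code it does not recognize, so new
 * extend codes can be added without breaking older peers.
 */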
static void hs_fill_extend_data(struct connection *conn_impl, __u8 ops,
				void *extend_data, __u32 len)
{
	struct conn_hs_extend_head *extend_head = NULL;
	struct extend_field_head *field = NULL;
	uint8_t *body = NULL;
	__u32 offset;
	__u16 i;

	if (sizeof(struct conn_hs_extend_head) > len) {
		hmdfs_info("len error. len=%u", len);
		return;
	}
	extend_head = (struct conn_hs_extend_head *)extend_data;
	extend_head->field_cn = 0;
	offset = sizeof(struct conn_hs_extend_head);

	for (i = 0; i < HS_EXTEND_CODE_COUNT; i++) {
		if (sizeof(struct extend_field_head) > (len - offset))
			break;
		field = (struct extend_field_head *)((uint8_t *)extend_data +
						     offset);
		offset += sizeof(struct extend_field_head);

		if (s_hs_extend_reg[i].len > (len - offset))
			break;
		body = (uint8_t *)extend_data + offset;
		offset += s_hs_extend_reg[i].len;

		field->code = cpu_to_le16(i);
		field->len = cpu_to_le16(s_hs_extend_reg[i].len);

		if (s_hs_extend_reg[i].filler)
			s_hs_extend_reg[i].filler(conn_impl, ops,
					body, s_hs_extend_reg[i].len);

		extend_head->field_cn += 1;
	}

	extend_head->field_cn = cpu_to_le32(extend_head->field_cn);
}

static int hs_parse_extend_data(struct connection *conn_impl, __u8 ops,
				void *extend_data, __u32 extend_len)
{
	struct conn_hs_extend_head *extend_head = NULL;
	struct extend_field_head *field = NULL;
	uint8_t *body = NULL;
	__u32 offset;
	__u32 field_cnt;
	__u16 code;
	__u16 len;
	int i;
	int ret;

	if (sizeof(struct conn_hs_extend_head) > extend_len) {
		hmdfs_err("ops=%u,extend_len=%u", ops, extend_len);
		return -1;
	}
	extend_head = (struct conn_hs_extend_head *)extend_data;
	field_cnt = le32_to_cpu(extend_head->field_cn);
	hmdfs_info("extend_len=%u,field_cnt=%u", extend_len, field_cnt);

	offset = sizeof(struct conn_hs_extend_head);

	for (i = 0; i < field_cnt; i++) {
		if (sizeof(struct extend_field_head) > (extend_len - offset)) {
			hmdfs_err("cnt err, op=%u, extend_len=%u, cnt=%u, i=%u",
				  ops, extend_len, field_cnt, i);
			return -1;
		}
		field = (struct extend_field_head *)((uint8_t *)extend_data +
						     offset);
		offset += sizeof(struct extend_field_head);
		code = le16_to_cpu(field->code);
		len = le16_to_cpu(field->len);
		if (len > (extend_len - offset)) {
			hmdfs_err("len err, op=%u, extend_len=%u, cnt=%u, i=%u",
				  ops, extend_len, field_cnt, i);
			hmdfs_err("len err, code=%u, len=%u, offset=%u", code,
				  len, offset);
			return -1;
		}

		body = (uint8_t *)extend_data + offset;
		offset += len;
		if ((code < HS_EXTEND_CODE_COUNT) &&
		    (s_hs_extend_reg[code].parser)) {
			ret = s_hs_extend_reg[code].parser(conn_impl, ops,
							   body, len);
			if (ret)
				return ret;
		}
	}
	return 0;
}

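/*
 * A handshake payload is a struct connection_handshake_req (carrying a
 * device-id string of hs_req->len bytes) followed by the extend data
 * described above; both lengths are validated before parsing.
 */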
static int hs_proc_msg_data(struct connection *conn_impl, __u8 ops, void *data,
			    __u32 data_len)
{
	struct connection_handshake_req *hs_req = NULL;
	uint8_t *extend_data = NULL;
	__u32 extend_len;
	__u32 req_len;
	int ret;

	if (!data) {
		hmdfs_err("err, msg data is null");
		return -1;
	}

	if (data_len < sizeof(struct connection_handshake_req)) {
		hmdfs_err("ack msg data len error. data_len=%u, device_id=%llu",
			  data_len, conn_impl->node->device_id);
		return -1;
	}

	hs_req = (struct connection_handshake_req *)data;
	req_len = le32_to_cpu(hs_req->len);
	if (req_len > (data_len - sizeof(struct connection_handshake_req))) {
		hmdfs_info(
			"ack msg hs_req len(%u) error. data_len=%u, device_id=%llu",
			req_len, data_len, conn_impl->node->device_id);
		return -1;
	}
	extend_len =
		data_len - sizeof(struct connection_handshake_req) - req_len;
	extend_data = (uint8_t *)data +
		      sizeof(struct connection_handshake_req) + req_len;
	ret = hs_parse_extend_data(conn_impl, ops, extend_data, extend_len);
	if (!ret)
		hmdfs_info(
			"hs msg rcv, ops=%u, data_len=%u, device_id=%llu, req_len=%u",
			ops, data_len, conn_impl->node->device_id, req_len);
	return ret;
}
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
static int connection_handshake_init_tls(struct connection *conn_impl, __u8 ops)
{
	// init ktls config, use key1/key2 as the initial write key of each direction
	__u8 key1[HMDFS_KEY_SIZE];
	__u8 key2[HMDFS_KEY_SIZE];
	int ret;

	if ((ops != CONNECT_MESG_HANDSHAKE_RESPONSE) &&
	    (ops != CONNECT_MESG_HANDSHAKE_ACK)) {
		hmdfs_err("ops %u is invalid", ops);
		return -EINVAL;
	}

	update_key(conn_impl->master_key, key1, HKDF_TYPE_KEY_INITIATOR);
	update_key(conn_impl->master_key, key2, HKDF_TYPE_KEY_ACCEPTER);

	if (ops == CONNECT_MESG_HANDSHAKE_ACK) {
		memcpy(conn_impl->send_key, key1, HMDFS_KEY_SIZE);
		memcpy(conn_impl->recv_key, key2, HMDFS_KEY_SIZE);
	} else {
		memcpy(conn_impl->send_key, key2, HMDFS_KEY_SIZE);
		memcpy(conn_impl->recv_key, key1, HMDFS_KEY_SIZE);
	}

	memset(key1, 0, HMDFS_KEY_SIZE);
	memset(key2, 0, HMDFS_KEY_SIZE);

	hmdfs_info("hs: ops=%u start set crypto tls", ops);
	ret = tls_crypto_info_init(conn_impl);
	if (ret)
		hmdfs_err("setting tls fail. ops is %u", ops);

	return ret;
}
#endif

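/*
 * Build and send a single handshake frame: a connection_msg_head, a
 * connection_handshake_req carrying the device-id string, and, for
 * RESPONSE and ACK frames, the handshake extend data.
 */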
static int do_send_handshake(struct connection *conn_impl, __u8 ops,
			     __le16 request_id)
{
	int err;
	struct connection_msg_head *hs_head = NULL;
	struct connection_handshake_req *hs_data = NULL;
	uint8_t *hs_extend_data = NULL;
	struct hmdfs_send_data msg;
	__u32 send_len;
	__u32 len;
	__u32 extend_len = 0;
	char buf[HMDFS_CID_SIZE] = { 0 };

	len = scnprintf(buf, HMDFS_CID_SIZE, "%llu", 0ULL);
	send_len = sizeof(struct connection_msg_head) +
		   sizeof(struct connection_handshake_req) + len;

	if ((ops == CONNECT_MESG_HANDSHAKE_RESPONSE) ||
	    (ops == CONNECT_MESG_HANDSHAKE_ACK)) {
		extend_len = hs_get_extend_data_len();
		send_len += extend_len;
	}

	hs_head = kzalloc(send_len, GFP_KERNEL);
	if (!hs_head)
		return -ENOMEM;

	hs_data = (struct connection_handshake_req
			   *)((uint8_t *)hs_head +
			      sizeof(struct connection_msg_head));

	hs_data->len = cpu_to_le32(len);
	memcpy(hs_data->dev_id, buf, len);

	if ((ops == CONNECT_MESG_HANDSHAKE_RESPONSE) ||
	    (ops == CONNECT_MESG_HANDSHAKE_ACK)) {
		hs_extend_data = (uint8_t *)hs_data +
				  sizeof(struct connection_handshake_req) + len;
		hs_fill_extend_data(conn_impl, ops, hs_extend_data, extend_len);
	}

	hs_head->magic = HMDFS_MSG_MAGIC;
	hs_head->version = HMDFS_VERSION;
	hs_head->flags |= 0x1;
	hmdfs_info("Send handshake message: ops = %d, fd = %d", ops,
		   ((struct tcp_handle *)(conn_impl->connect_handle))->fd);
	hs_head->operations = ops;
	hs_head->request_id = request_id;
	hs_head->datasize = cpu_to_le32(send_len);
	hs_head->source = 0;
	hs_head->msg_id = 0;

	msg.head = hs_head;
	msg.head_len = sizeof(struct connection_msg_head);
	msg.data = hs_data;
	msg.len = send_len - msg.head_len;
	msg.sdesc = NULL;
	msg.sdesc_len = 0;
	err = conn_impl->send_message(conn_impl, &msg);
	kfree(hs_head);
	return err;
}

static int hmdfs_node_waiting_evt_sum(const struct hmdfs_peer *node)
{
	int sum = 0;
	int i;

	for (i = 0; i < RAW_NODE_EVT_NR; i++)
		sum += node->waiting_evt[i];

	return sum;
}

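/*
 * Queue a raw ON/OFF event while another event is still pending.
 * Consecutive duplicates are only counted, and an OFF-ON-OFF pattern in
 * the waiting counters is merged back into a single OFF. Returns 1 if
 * the caller should run callbacks for this event with the sequence
 * number stored in *seq, 0 if the event was dropped as a duplicate.
 */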
static int hmdfs_update_node_waiting_evt(struct hmdfs_peer *node, int evt,
					 unsigned int *seq)
{
	int last;
	int sum;
	unsigned int next;

	sum = hmdfs_node_waiting_evt_sum(node);
	if (sum % RAW_NODE_EVT_NR)
		last = !node->pending_evt;
	else
		last = node->pending_evt;

	/* duplicated event */
	if (evt == last) {
		node->dup_evt[evt]++;
		return 0;
	}

	node->waiting_evt[evt]++;
	hmdfs_debug("add node->waiting_evt[%d]=%d", evt,
		    node->waiting_evt[evt]);

	/* offline wait + online wait + offline wait = offline wait
	 * online wait + offline wait + online wait != online wait
	 * As the first online related resource (e.g. fd) must be invalidated
	 */
	if (node->waiting_evt[RAW_NODE_EVT_OFF] >= 2 &&
	    node->waiting_evt[RAW_NODE_EVT_ON] >= 1) {
		node->waiting_evt[RAW_NODE_EVT_OFF] -= 1;
		node->waiting_evt[RAW_NODE_EVT_ON] -= 1;
		node->seq_wr_idx -= 2;
		node->merged_evt += 2;
	}

	next = hmdfs_node_inc_evt_seq(node);
	node->seq_tbl[(node->seq_wr_idx++) % RAW_NODE_EVT_MAX_NR] = next;
	*seq = next;

	return 1;
}

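/*
 * Record which raw event is being dispatched (cur_evt/cur_evt_seq,
 * slot 1 for sync callbacks, slot 0 for async ones) for the duration
 * of the callback run.
 */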
static void hmdfs_run_evt_cb_verbosely(struct hmdfs_peer *node, int raw_evt,
				       bool sync, unsigned int seq)
{
	int evt = (raw_evt == RAW_NODE_EVT_OFF) ? NODE_EVT_OFFLINE :
						  NODE_EVT_ONLINE;
	int cur_evt_idx = sync ? 1 : 0;

	node->cur_evt[cur_evt_idx] = raw_evt;
	node->cur_evt_seq[cur_evt_idx] = seq;
	hmdfs_node_call_evt_cb(node, evt, sync, seq);
	node->cur_evt[cur_evt_idx] = RAW_NODE_EVT_NR;
}

static void hmdfs_node_evt_work(struct work_struct *work)
{
	struct hmdfs_peer *node =
		container_of(work, struct hmdfs_peer, evt_dwork.work);
	unsigned int seq;

	/*
	 * N-th sync cb completes before N-th async cb,
	 * so use seq_lock as a barrier in read & write path
	 * to ensure we can read the required seq.
	 */
	mutex_lock(&node->seq_lock);
	seq = node->seq_tbl[(node->seq_rd_idx++) % RAW_NODE_EVT_MAX_NR];
	hmdfs_run_evt_cb_verbosely(node, node->pending_evt, false, seq);
	mutex_unlock(&node->seq_lock);

	mutex_lock(&node->evt_lock);
	if (hmdfs_node_waiting_evt_sum(node)) {
		node->pending_evt = !node->pending_evt;
		node->pending_evt_seq =
			node->seq_tbl[node->seq_rd_idx % RAW_NODE_EVT_MAX_NR];
		node->waiting_evt[node->pending_evt]--;
		/* sync cb has been done */
		schedule_delayed_work(&node->evt_dwork,
				      node->sbi->async_cb_delay * HZ);
	} else {
		node->last_evt = node->pending_evt;
		node->pending_evt = RAW_NODE_EVT_NR;
	}
	mutex_unlock(&node->evt_lock);
}

/*
 * The running orders of cb are:
 *
 * (1) sync callbacks are invoked according to the queue order of raw events:
 *     ensured by seq_lock.
 * (2) async callbacks are invoked according to the queue order of raw events:
 *     ensured by evt_lock & evt_dwork
 * (3) async callback is invoked after sync callback of the same raw event:
 *     ensured by seq_lock.
 * (4) async callback of N-th raw event and sync callback of (N+x)-th raw
 *     event can run concurrently.
 */
static void hmdfs_queue_raw_node_evt(struct hmdfs_peer *node, int evt)
{
	unsigned int seq = 0;

	mutex_lock(&node->evt_lock);
	if (node->pending_evt == RAW_NODE_EVT_NR) {
		if (evt == node->last_evt) {
			node->dup_evt[evt]++;
			mutex_unlock(&node->evt_lock);
			return;
		}
		node->pending_evt = evt;
		seq = hmdfs_node_inc_evt_seq(node);
		node->seq_tbl[(node->seq_wr_idx++) % RAW_NODE_EVT_MAX_NR] = seq;
		node->pending_evt_seq = seq;
		mutex_lock(&node->seq_lock);
		mutex_unlock(&node->evt_lock);
		/* call sync cb, then async cb */
		hmdfs_run_evt_cb_verbosely(node, evt, true, seq);
		mutex_unlock(&node->seq_lock);
		schedule_delayed_work(&node->evt_dwork,
				      node->sbi->async_cb_delay * HZ);
	} else if (hmdfs_update_node_waiting_evt(node, evt, &seq) > 0) {
		/*
		 * Take seq_lock firstly to ensure N-th sync cb
		 * is called before N-th async cb.
		 */
		mutex_lock(&node->seq_lock);
		mutex_unlock(&node->evt_lock);
		hmdfs_run_evt_cb_verbosely(node, evt, true, seq);
		mutex_unlock(&node->seq_lock);
	} else {
		mutex_unlock(&node->evt_lock);
	}
}

void connection_send_handshake(struct connection *conn_impl, __u8 ops,
			       __le16 request_id)
{
	struct tcp_handle *tcp = NULL;
	int err = do_send_handshake(conn_impl, ops, request_id);

	if (likely(err >= 0))
		return;

	tcp = conn_impl->connect_handle;
	hmdfs_err("Failed to send handshake: err = %d, fd = %d", err, tcp->fd);
	hmdfs_reget_connection(conn_impl);
}

void connection_handshake_notify(struct hmdfs_peer *node, int notify_type)
{
	struct notify_param param;

	param.notify = notify_type;
	param.fd = INVALID_SOCKET_FD;
	memcpy(param.remote_cid, node->cid, HMDFS_CID_SIZE);
	notify(node, &param);
}

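/*
 * Mark the peer online exactly once: xchg() lets concurrent callers
 * race safely, and only the caller that actually flips the status
 * records the connect time and queues RAW_NODE_EVT_ON.
 */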
void peer_online(struct hmdfs_peer *peer)
{
	// To evaluate if someone else has made the peer online
	u8 prev_stat = xchg(&peer->status, NODE_STAT_ONLINE);
	unsigned long jif_tmp = jiffies;

	if (prev_stat == NODE_STAT_ONLINE)
		return;
	WRITE_ONCE(peer->conn_time, jif_tmp);
	WRITE_ONCE(peer->sbi->connections.recent_ol, jif_tmp);
	hmdfs_queue_raw_node_evt(peer, RAW_NODE_EVT_ON);
}

void connection_to_working(struct hmdfs_peer *node)
{
	struct connection *conn_impl = NULL;
	struct tcp_handle *tcp = NULL;

	if (!node)
		return;
	mutex_lock(&node->conn_impl_list_lock);
	list_for_each_entry(conn_impl, &node->conn_impl_list, list) {
		if (conn_impl->type == CONNECT_TYPE_TCP &&
		    conn_impl->status == CONNECT_STAT_WAIT_RESPONSE) {
			tcp = conn_impl->connect_handle;
			hmdfs_info("fd %d to working", tcp->fd);
			conn_impl->status = CONNECT_STAT_WORKING;
		}
	}
	mutex_unlock(&node->conn_impl_list_lock);
	peer_online(node);
}

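/*
 * Handshake receive state machine, keyed on the received operation:
 *
 *   REQUEST  -> send RESPONSE, wait for ACK
 *   RESPONSE -> parse peer data, send ACK, set up ktls if enabled,
 *               mark the connection working and the peer online
 *   ACK      -> parse peer data, set up ktls if enabled,
 *               mark the connection working and the peer online
 *
 * A negotiation failure marks the connection CONNECT_STAT_NEGO_FAIL
 * and raises a NOTIFY_OFFLINE notification.
 */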
void connection_handshake_recv_handler(struct connection *conn_impl, void *buf,
				       void *data, __u32 data_len)
{
	__u8 ops;
	__u8 status;
	int fd = ((struct tcp_handle *)(conn_impl->connect_handle))->fd;
	struct connection_msg_head *head = (struct connection_msg_head *)buf;
	int ret;

	if (head->version != HMDFS_VERSION)
		goto out;

	conn_impl->node->version = head->version;
	ops = head->operations;
	status = conn_impl->status;
	switch (ops) {
	case CONNECT_MESG_HANDSHAKE_REQUEST:
		hmdfs_info(
			"Recved handshake request: device_id = %llu, head->len = %d, tcp->fd = %d",
			conn_impl->node->device_id, head->datasize, fd);
		connection_send_handshake(conn_impl,
					  CONNECT_MESG_HANDSHAKE_RESPONSE,
					  head->msg_id);
		conn_impl->status = CONNECT_STAT_WAIT_ACK;
		conn_impl->node->status = NODE_STAT_SHAKING;
		break;
	case CONNECT_MESG_HANDSHAKE_RESPONSE:
		hmdfs_info(
			"Recved handshake response: device_id = %llu, cmd->status = %hhu, tcp->fd = %d",
			conn_impl->node->device_id, status, fd);

		ret = hs_proc_msg_data(conn_impl, ops, data, data_len);
		if (ret)
			goto nego_err;
		connection_send_handshake(conn_impl,
					  CONNECT_MESG_HANDSHAKE_ACK,
					  head->msg_id);
		hmdfs_info("response rcv handle, conn_impl->crypto=0x%0x",
				conn_impl->crypto);
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
		ret = connection_handshake_init_tls(conn_impl, ops);
		if (ret) {
			hmdfs_err("init_tls_key fail, ops %u", ops);
			goto out;
		}
#endif

		conn_impl->status = CONNECT_STAT_WORKING;
		peer_online(conn_impl->node);
		break;
	case CONNECT_MESG_HANDSHAKE_ACK:
		ret = hs_proc_msg_data(conn_impl, ops, data, data_len);
		if (ret)
			goto nego_err;
		hmdfs_info("ack rcv handle, conn_impl->crypto=0x%0x",
				conn_impl->crypto);
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
		ret = connection_handshake_init_tls(conn_impl, ops);
		if (ret) {
			hmdfs_err("init_tls_key fail, ops %u", ops);
			goto out;
		}
#endif
		conn_impl->status = CONNECT_STAT_WORKING;
		peer_online(conn_impl->node);
		break;
	default:
		break;
	}
out:
	kfree(data);
	return;
nego_err:
	conn_impl->status = CONNECT_STAT_NEGO_FAIL;
	connection_handshake_notify(conn_impl->node, NOTIFY_OFFLINE);
	hmdfs_err("protocol negotiation failed, remote device_id = %llu, tcp->fd = %d",
		  conn_impl->node->device_id, fd);
	goto out;
}

#ifdef CONFIG_HMDFS_FS_ENCRYPTION
static void update_tls_crypto_key(struct connection *conn,
				  struct hmdfs_head_cmd *head, void *data,
				  __u32 data_len)
{
	// rekey message handler
	struct connection_rekey_request *rekey_req = NULL;
	int ret = 0;

	if (hmdfs_message_verify(conn->node, head, data) < 0) {
		hmdfs_err("Rekey msg %d has been abandoned", head->msg_id);
		goto out_err;
	}

	hmdfs_info("recv REKEY request");
	set_crypto_info(conn, SET_CRYPTO_RECV);
	// update send key if requested
	rekey_req = data;
	if (le32_to_cpu(rekey_req->update_request) == UPDATE_REQUESTED) {
		ret = tcp_send_rekey_request(conn);
		if (ret == 0)
			set_crypto_info(conn, SET_CRYPTO_SEND);
	}
out_err:
	kfree(data);
}

static bool cmd_update_tls_crypto_key(struct connection *conn,
				      struct hmdfs_head_cmd *head)
{
	struct tcp_handle *tcp = conn->connect_handle;

	if (conn->type != CONNECT_TYPE_TCP || !tcp)
		return false;
	return head->operations.command == F_CONNECT_REKEY;
}
#endif

void connection_working_recv_handler(struct connection *conn_impl, void *buf,
				     void *data, __u32 data_len)
{
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	if (cmd_update_tls_crypto_key(conn_impl, buf)) {
		update_tls_crypto_key(conn_impl, buf, data, data_len);
		return;
	}
#endif
	hmdfs_recv_mesg_callback(conn_impl->node, buf, data);
}

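/*
 * Final put of a connection: wipe the key material, close and release
 * the socket and AEAD transform, unlink the connection from its peer
 * and wake up hmdfs_disconnect_node(), which may be waiting for the
 * deleting list to drain.
 */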
static void connection_release(struct kref *ref)
{
	struct tcp_handle *tcp = NULL;
	struct connection *conn = container_of(ref, struct connection, ref_cnt);

	hmdfs_info("connection release");
	memset(conn->master_key, 0, HMDFS_KEY_SIZE);
	memset(conn->send_key, 0, HMDFS_KEY_SIZE);
	memset(conn->recv_key, 0, HMDFS_KEY_SIZE);
	if (conn->close)
		conn->close(conn);
	tcp = conn->connect_handle;
	crypto_free_aead(conn->tfm);
	// need to check and test: fput(tcp->sock->file);
	if (tcp && tcp->sock) {
		hmdfs_info("connection release: fd = %d, refcount %ld", tcp->fd,
			   file_count(tcp->sock->file));
		sockfd_put(tcp->sock);
	}
	if (tcp && tcp->recv_cache)
		kmem_cache_destroy(tcp->recv_cache);

	if (!list_empty(&conn->list)) {
		mutex_lock(&conn->node->conn_impl_list_lock);
		list_del(&conn->list);
		mutex_unlock(&conn->node->conn_impl_list_lock);
		/*
		 * wake up hmdfs_disconnect_node to check
		 * whether conn_deleting_list is empty.
		 */
		wake_up_interruptible(&conn->node->deleting_list_wq);
	}

	kfree(tcp);
	kfree(conn);
}

static void hmdfs_peer_release(struct kref *ref)
{
	struct hmdfs_peer *peer = container_of(ref, struct hmdfs_peer, ref_cnt);
	struct mutex *lock = &peer->sbi->connections.node_lock;

	if (!list_empty(&peer->list))
		hmdfs_info("releasing an on-sbi peer: device_id %llu",
			   peer->device_id);
	else
		hmdfs_info("releasing a redundant peer: device_id %llu",
			   peer->device_id);

	cancel_delayed_work_sync(&peer->evt_dwork);
	list_del(&peer->list);
	idr_destroy(&peer->msg_idr);
	idr_destroy(&peer->file_id_idr);
	flush_workqueue(peer->req_handle_wq);
	flush_workqueue(peer->async_wq);
	flush_workqueue(peer->retry_wb_wq);
	destroy_workqueue(peer->dentry_wq);
	destroy_workqueue(peer->req_handle_wq);
	destroy_workqueue(peer->async_wq);
	destroy_workqueue(peer->retry_wb_wq);
	destroy_workqueue(peer->reget_conn_wq);
	kfree(peer);
	mutex_unlock(lock);
}

void connection_put(struct connection *conn)
{
	struct mutex *lock = &conn->ref_lock;

	kref_put_mutex(&conn->ref_cnt, connection_release, lock);
}

void peer_put(struct hmdfs_peer *peer)
{
	struct mutex *lock = &peer->sbi->connections.node_lock;

	kref_put_mutex(&peer->ref_cnt, hmdfs_peer_release, lock);
}

static void hmdfs_dump_deleting_list(struct hmdfs_peer *node)
{
	struct connection *con = NULL;
	struct tcp_handle *tcp = NULL;
	int count = 0;

	mutex_lock(&node->conn_impl_list_lock);
	list_for_each_entry(con, &node->conn_deleting_list, list) {
		tcp = con->connect_handle;
		hmdfs_info("deleting list %d:device_id %llu tcp_fd %d refcnt %d",
			   count, node->device_id, tcp ? tcp->fd : -1,
			   kref_read(&con->ref_cnt));
		count++;
	}
	mutex_unlock(&node->conn_impl_list_lock);
}

static bool hmdfs_conn_deleting_list_empty(struct hmdfs_peer *node)
{
	bool empty = false;

	mutex_lock(&node->conn_impl_list_lock);
	empty = list_empty(&node->conn_deleting_list);
	mutex_unlock(&node->conn_impl_list_lock);

	return empty;
}

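/*
 * Tear down every connection of a peer: mark the node offline, shut
 * down the sockets, wait (with a timeout) for releases on the deleting
 * list and for in-flight requests to finish, then queue the offline
 * event.
 */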
void hmdfs_disconnect_node(struct hmdfs_peer *node)
{
	LIST_HEAD(local_conns);
	struct connection *conn_impl = NULL;
	struct connection *next = NULL;
	struct tcp_handle *tcp = NULL;

	if (unlikely(!node))
		return;

	hmdfs_node_inc_evt_seq(node);
	/* Refer to comments in hmdfs_is_node_offlined() */
	smp_mb__after_atomic();
	node->status = NODE_STAT_OFFLINE;
	hmdfs_info("Try to disconnect peer: device_id %llu", node->device_id);

	mutex_lock(&node->conn_impl_list_lock);
	if (!list_empty(&node->conn_impl_list))
		list_replace_init(&node->conn_impl_list, &local_conns);
	mutex_unlock(&node->conn_impl_list_lock);

	list_for_each_entry_safe(conn_impl, next, &local_conns, list) {
		tcp = conn_impl->connect_handle;
		if (tcp && tcp->sock) {
			kernel_sock_shutdown(tcp->sock, SHUT_RDWR);
			hmdfs_info("shutdown sock: fd = %d, refcount %ld",
				   tcp->fd, file_count(tcp->sock->file));
		}
		if (tcp)
			tcp->fd = INVALID_SOCKET_FD;

		tcp_close_socket(tcp);
		list_del_init(&conn_impl->list);

		connection_put(conn_impl);
	}

	if (wait_event_interruptible_timeout(node->deleting_list_wq,
					hmdfs_conn_deleting_list_empty(node),
					HMDFS_WAIT_CONN_RELEASE) <= 0)
		hmdfs_dump_deleting_list(node);

	/* wait for all requests to finish processing */
	spin_lock(&node->idr_lock);
	while (node->msg_idr_process) {
		spin_unlock(&node->idr_lock);
		usleep_range(HMDFS_WAIT_REQUEST_END_MIN,
			     HMDFS_WAIT_REQUEST_END_MAX);
		spin_lock(&node->idr_lock);
	}
	spin_unlock(&node->idr_lock);

	hmdfs_queue_raw_node_evt(node, RAW_NODE_EVT_OFF);
}

static void hmdfs_run_simple_evt_cb(struct hmdfs_peer *node, int evt)
{
	unsigned int seq = hmdfs_node_inc_evt_seq(node);

	mutex_lock(&node->seq_lock);
	hmdfs_node_call_evt_cb(node, evt, true, seq);
	mutex_unlock(&node->seq_lock);
}

static void hmdfs_del_peer(struct hmdfs_peer *node)
{
	/*
	 * No need for offline evt cb, because all files must
	 * have been flushed and closed, else the filesystem
	 * will be un-mountable.
	 */
	cancel_delayed_work_sync(&node->evt_dwork);

	hmdfs_run_simple_evt_cb(node, NODE_EVT_DEL);

	hmdfs_release_peer_sysfs(node);

	flush_workqueue(node->reget_conn_wq);
	peer_put(node);
}

void hmdfs_connections_stop(struct hmdfs_sb_info *sbi)
{
	struct hmdfs_peer *node = NULL;
	struct hmdfs_peer *con_tmp = NULL;

	mutex_lock(&sbi->connections.node_lock);
	list_for_each_entry_safe(node, con_tmp, &sbi->connections.node_list,
				  list) {
		mutex_unlock(&sbi->connections.node_lock);
		hmdfs_disconnect_node(node);
		hmdfs_del_peer(node);
		mutex_lock(&sbi->connections.node_lock);
	}
	mutex_unlock(&sbi->connections.node_lock);
}

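/*
 * Return a working connection of the requested type with an extra
 * reference held, or NULL if none exists; the caller must drop the
 * reference with connection_put().
 */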
struct connection *get_conn_impl(struct hmdfs_peer *node, int connect_type)
{
	struct connection *conn_impl = NULL;

	if (!node)
		return NULL;
	mutex_lock(&node->conn_impl_list_lock);
	list_for_each_entry(conn_impl, &node->conn_impl_list, list) {
		if (conn_impl->type == connect_type &&
		    conn_impl->status == CONNECT_STAT_WORKING) {
			connection_get(conn_impl);
			mutex_unlock(&node->conn_impl_list_lock);
			return conn_impl;
		}
	}
	mutex_unlock(&node->conn_impl_list_lock);
	hmdfs_err_ratelimited("no working connection found for device %llu, type %d",
			      node->device_id, connect_type);
	return NULL;
}

void set_conn_sock_quickack(struct hmdfs_peer *node)
{
	struct connection *conn_impl = NULL;
	struct tcp_handle *tcp = NULL;
	int option = 1;

	if (!node)
		return;
	mutex_lock(&node->conn_impl_list_lock);
	list_for_each_entry(conn_impl, &node->conn_impl_list, list) {
		if (conn_impl->type == CONNECT_TYPE_TCP &&
		    conn_impl->status == CONNECT_STAT_WORKING &&
		    conn_impl->connect_handle) {
			tcp = (struct tcp_handle *)(conn_impl->connect_handle);
			tcp_sock_set_quickack(tcp->sock->sk, option);
		}
	}
	mutex_unlock(&node->conn_impl_list_lock);
}

struct hmdfs_peer *hmdfs_lookup_from_devid(struct hmdfs_sb_info *sbi,
					   uint64_t device_id)
{
	struct hmdfs_peer *con = NULL;
	struct hmdfs_peer *lookup = NULL;

	if (!sbi)
		return NULL;
	mutex_lock(&sbi->connections.node_lock);
	list_for_each_entry(con, &sbi->connections.node_list, list) {
		if (con->status != NODE_STAT_ONLINE ||
		    con->device_id != device_id)
			continue;
		lookup = con;
		peer_get(lookup);
		break;
	}
	mutex_unlock(&sbi->connections.node_lock);
	return lookup;
}

struct hmdfs_peer *hmdfs_lookup_from_cid(struct hmdfs_sb_info *sbi,
					 uint8_t *cid)
{
	struct hmdfs_peer *con = NULL;
	struct hmdfs_peer *lookup = NULL;

	if (!sbi)
		return NULL;
	mutex_lock(&sbi->connections.node_lock);
	list_for_each_entry(con, &sbi->connections.node_list, list) {
		if (strncmp(con->cid, cid, HMDFS_CID_SIZE) != 0)
			continue;
		lookup = con;
		peer_get(lookup);
		break;
	}
	mutex_unlock(&sbi->connections.node_lock);
	return lookup;
}

static struct hmdfs_peer *lookup_peer_by_cid_unsafe(struct hmdfs_sb_info *sbi,
						    uint8_t *cid)
{
	struct hmdfs_peer *node = NULL;

	list_for_each_entry(node, &sbi->connections.node_list, list)
		if (!strncmp(node->cid, cid, HMDFS_CID_SIZE)) {
			peer_get(node);
			return node;
		}
	return NULL;
}

static struct hmdfs_peer *add_peer_unsafe(struct hmdfs_sb_info *sbi,
					  struct hmdfs_peer *peer2add)
{
	struct hmdfs_peer *peer;
	int err;

	peer = lookup_peer_by_cid_unsafe(sbi, peer2add->cid);
	if (peer)
		return peer;

	err = hmdfs_register_peer_sysfs(sbi, peer2add);
	if (err) {
		hmdfs_err("register peer %llu sysfs err %d",
			  peer2add->device_id, err);
		return ERR_PTR(err);
	}
	list_add_tail(&peer2add->list, &sbi->connections.node_list);
	peer_get(peer2add);
	hmdfs_run_simple_evt_cb(peer2add, NODE_EVT_ADD);
	return peer2add;
}

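/*
 * Allocate a peer that is not yet linked into the sbi and initialize
 * its per-peer workqueues, idrs, locks, stash lists and node-event
 * bookkeeping. Returns NULL on any allocation failure.
 */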
static struct hmdfs_peer *alloc_peer(struct hmdfs_sb_info *sbi, uint8_t *cid,
	uint32_t devsl)
{
	struct hmdfs_peer *node = kzalloc(sizeof(*node), GFP_KERNEL);

	if (!node)
		return NULL;

	node->device_id = (u32)atomic_inc_return(&sbi->connections.conn_seq);

	node->async_wq = alloc_workqueue("dfs_async%u_%llu", WQ_MEM_RECLAIM, 0,
					 sbi->seq, node->device_id);
	if (!node->async_wq) {
		hmdfs_err("Failed to alloc async wq");
		goto out_err;
	}
	node->req_handle_wq = alloc_workqueue("dfs_req%u_%llu",
					      WQ_UNBOUND | WQ_MEM_RECLAIM,
					      sbi->async_req_max_active,
					      sbi->seq, node->device_id);
	if (!node->req_handle_wq) {
		hmdfs_err("Failed to alloc req wq");
		goto out_err;
	}
	node->dentry_wq = alloc_workqueue("dfs_dentry%u_%llu",
					   WQ_UNBOUND | WQ_MEM_RECLAIM,
					   0, sbi->seq, node->device_id);
	if (!node->dentry_wq) {
		hmdfs_err("Failed to alloc dentry wq");
		goto out_err;
	}
	node->retry_wb_wq = alloc_workqueue("dfs_rwb%u_%llu",
					   WQ_UNBOUND | WQ_MEM_RECLAIM,
					   HMDFS_RETRY_WB_WQ_MAX_ACTIVE,
					   sbi->seq, node->device_id);
	if (!node->retry_wb_wq) {
		hmdfs_err("Failed to alloc retry writeback wq");
		goto out_err;
	}
	node->reget_conn_wq = alloc_workqueue("dfs_regetcon%u_%llu",
					      WQ_UNBOUND, 0,
					      sbi->seq, node->device_id);
	if (!node->reget_conn_wq) {
		hmdfs_err("Failed to alloc reget conn wq");
		goto out_err;
	}
	INIT_LIST_HEAD(&node->conn_impl_list);
	mutex_init(&node->conn_impl_list_lock);
	INIT_LIST_HEAD(&node->conn_deleting_list);
	init_waitqueue_head(&node->deleting_list_wq);
	idr_init(&node->msg_idr);
	spin_lock_init(&node->idr_lock);
	idr_init(&node->file_id_idr);
	spin_lock_init(&node->file_id_lock);
	INIT_LIST_HEAD(&node->list);
	kref_init(&node->ref_cnt);
	node->owner = sbi->seq;
	node->sbi = sbi;
	node->version = HMDFS_VERSION;
	node->status = NODE_STAT_SHAKING;
	node->conn_time = jiffies;
	memcpy(node->cid, cid, HMDFS_CID_SIZE);
	atomic64_set(&node->sb_dirty_count, 0);
	node->fid_cookie = 0;
	atomic_set(&node->evt_seq, 0);
	mutex_init(&node->seq_lock);
	mutex_init(&node->offline_cb_lock);
	mutex_init(&node->evt_lock);
	node->pending_evt = RAW_NODE_EVT_NR;
	node->last_evt = RAW_NODE_EVT_NR;
	node->cur_evt[0] = RAW_NODE_EVT_NR;
	node->cur_evt[1] = RAW_NODE_EVT_NR;
	node->seq_wr_idx = (unsigned char)UINT_MAX;
	node->seq_rd_idx = node->seq_wr_idx;
	INIT_DELAYED_WORK(&node->evt_dwork, hmdfs_node_evt_work);
	node->msg_idr_process = 0;
	node->offline_start = false;
	spin_lock_init(&node->wr_opened_inode_lock);
	INIT_LIST_HEAD(&node->wr_opened_inode_list);
	spin_lock_init(&node->stashed_inode_lock);
	node->stashed_inode_nr = 0;
	atomic_set(&node->rebuild_inode_status_nr, 0);
	init_waitqueue_head(&node->rebuild_inode_status_wq);
	INIT_LIST_HEAD(&node->stashed_inode_list);
	node->need_rebuild_stash_list = false;
	node->devsl = devsl;

	return node;

out_err:
	if (node->async_wq) {
		destroy_workqueue(node->async_wq);
		node->async_wq = NULL;
	}
	if (node->req_handle_wq) {
		destroy_workqueue(node->req_handle_wq);
		node->req_handle_wq = NULL;
	}
	if (node->dentry_wq) {
		destroy_workqueue(node->dentry_wq);
		node->dentry_wq = NULL;
	}
	if (node->retry_wb_wq) {
		destroy_workqueue(node->retry_wb_wq);
		node->retry_wb_wq = NULL;
	}
	if (node->reget_conn_wq) {
		destroy_workqueue(node->reget_conn_wq);
		node->reget_conn_wq = NULL;
	}
	kfree(node);
	return NULL;
}

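/*
 * Look up a peer by cid, allocating and registering a new one if none
 * exists. A reference is returned to the caller on success; if another
 * thread inserted the same cid concurrently, the peer already on the
 * sbi list wins and the freshly allocated one is dropped.
 */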
struct hmdfs_peer *hmdfs_get_peer(struct hmdfs_sb_info *sbi, uint8_t *cid,
	uint32_t devsl)
{
	struct hmdfs_peer *peer = NULL, *on_sbi_peer = NULL;

	mutex_lock(&sbi->connections.node_lock);
	peer = lookup_peer_by_cid_unsafe(sbi, cid);
	mutex_unlock(&sbi->connections.node_lock);
	if (peer) {
		hmdfs_info("Got an existing peer: device_id = %llu",
			   peer->device_id);
		goto out;
	}

	peer = alloc_peer(sbi, cid, devsl);
	if (unlikely(!peer)) {
		hmdfs_info("Failed to alloc a peer");
		goto out;
	}

	mutex_lock(&sbi->connections.node_lock);
	on_sbi_peer = add_peer_unsafe(sbi, peer);
	mutex_unlock(&sbi->connections.node_lock);
	if (IS_ERR(on_sbi_peer)) {
		peer_put(peer);
		peer = NULL;
		goto out;
	} else if (unlikely(on_sbi_peer != peer)) {
		hmdfs_info("Got an existing peer: device_id = %llu",
			   on_sbi_peer->device_id);
		peer_put(peer);
		peer = on_sbi_peer;
	} else {
		hmdfs_info("Got a newly allocated peer: device_id = %llu",
			   peer->device_id);
	}

out:
	return peer;
}

static void head_release(struct kref *kref)
{
	struct hmdfs_msg_idr_head *head;
	struct hmdfs_peer *con;

	head = (struct hmdfs_msg_idr_head *)container_of(kref,
			struct hmdfs_msg_idr_head, ref);
	con = head->peer;
	idr_remove(&con->msg_idr, head->msg_id);
	spin_unlock(&con->idr_lock);

	kfree(head);
}

void head_put(struct hmdfs_msg_idr_head *head)
{
	kref_put_lock(&head->ref, head_release, &head->peer->idr_lock);
}

struct hmdfs_msg_idr_head *hmdfs_find_msg_head(struct hmdfs_peer *peer,
					int id, struct hmdfs_cmd operations)
{
	struct hmdfs_msg_idr_head *head = NULL;

	spin_lock(&peer->idr_lock);
	head = idr_find(&peer->msg_idr, id);
	if (head && head->send_cmd_operations.command == operations.command)
		kref_get(&head->ref);
	else
		head = NULL;
	spin_unlock(&peer->idr_lock);

	return head;
}

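/*
 * Install a message head into the peer's msg idr. Allocation is
 * refused once offline processing has started (peer->offline_start),
 * which lets hmdfs_disconnect_node() wait for msg_idr_process to drop
 * to zero without new entries appearing.
 */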
int hmdfs_alloc_msg_idr(struct hmdfs_peer *peer, enum MSG_IDR_TYPE type,
			void *ptr, struct hmdfs_cmd operations)
{
	int ret = -EAGAIN;
	struct hmdfs_msg_idr_head *head = ptr;

	idr_preload(GFP_KERNEL);
	spin_lock(&peer->idr_lock);
	if (!peer->offline_start)
		ret = idr_alloc_cyclic(&peer->msg_idr, ptr,
				       1, 0, GFP_NOWAIT);
	if (ret >= 0) {
		kref_init(&head->ref);
		head->msg_id = ret;
		head->type = type;
		head->peer = peer;
		head->send_cmd_operations = operations;
		peer->msg_idr_process++;
		ret = 0;
	}
	spin_unlock(&peer->idr_lock);
	idr_preload_end();

	return ret;
}