xref: /kernel/linux/linux-5.10/net/ipv6/seg6_local.c (revision 8c2ecf20)
1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 *  SR-IPv6 implementation
4 *
5 *  Authors:
6 *  David Lebrun <david.lebrun@uclouvain.be>
7 *  eBPF support: Mathieu Xhonneux <m.xhonneux@gmail.com>
8 */
9
10#include <linux/types.h>
11#include <linux/skbuff.h>
12#include <linux/net.h>
13#include <linux/module.h>
14#include <net/ip.h>
15#include <net/lwtunnel.h>
16#include <net/netevent.h>
17#include <net/netns/generic.h>
18#include <net/ip6_fib.h>
19#include <net/route.h>
20#include <net/seg6.h>
21#include <linux/seg6.h>
22#include <linux/seg6_local.h>
23#include <net/addrconf.h>
24#include <net/ip6_route.h>
25#include <net/dst_cache.h>
26#include <net/ip_tunnels.h>
27#ifdef CONFIG_IPV6_SEG6_HMAC
28#include <net/seg6_hmac.h>
29#endif
30#include <net/seg6_local.h>
31#include <linux/etherdevice.h>
32#include <linux/bpf.h>
33
34struct seg6_local_lwt;
35
36struct seg6_action_desc {
37	int action;
38	unsigned long attrs;
39	int (*input)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
40	int static_headroom;
41};
42
43struct bpf_lwt_prog {
44	struct bpf_prog *prog;
45	char *name;
46};
47
48struct seg6_local_lwt {
49	int action;
50	struct ipv6_sr_hdr *srh;
51	int table;
52	struct in_addr nh4;
53	struct in6_addr nh6;
54	int iif;
55	int oif;
56	struct bpf_lwt_prog bpf;
57
58	int headroom;
59	struct seg6_action_desc *desc;
60};
61
62static struct seg6_local_lwt *seg6_local_lwtunnel(struct lwtunnel_state *lwt)
63{
64	return (struct seg6_local_lwt *)lwt->data;
65}
66
67static struct ipv6_sr_hdr *get_srh(struct sk_buff *skb)
68{
69	struct ipv6_sr_hdr *srh;
70	int len, srhoff = 0;
71
72	if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0)
73		return NULL;
74
75	if (!pskb_may_pull(skb, srhoff + sizeof(*srh)))
76		return NULL;
77
78	srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
79
80	len = (srh->hdrlen + 1) << 3;
81
82	if (!pskb_may_pull(skb, srhoff + len))
83		return NULL;
84
85	/* note that pskb_may_pull may change pointers in header;
86	 * for this reason it is necessary to reload them when needed.
87	 */
88	srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
89
90	if (!seg6_validate_srh(srh, len, true))
91		return NULL;
92
93	return srh;
94}
95
96static struct ipv6_sr_hdr *get_and_validate_srh(struct sk_buff *skb)
97{
98	struct ipv6_sr_hdr *srh;
99
100	srh = get_srh(skb);
101	if (!srh)
102		return NULL;
103
104	if (srh->segments_left == 0)
105		return NULL;
106
107#ifdef CONFIG_IPV6_SEG6_HMAC
108	if (!seg6_hmac_validate_skb(skb))
109		return NULL;
110#endif
111
112	return srh;
113}
114
115static bool decap_and_validate(struct sk_buff *skb, int proto)
116{
117	struct ipv6_sr_hdr *srh;
118	unsigned int off = 0;
119
120	srh = get_srh(skb);
121	if (srh && srh->segments_left > 0)
122		return false;
123
124#ifdef CONFIG_IPV6_SEG6_HMAC
125	if (srh && !seg6_hmac_validate_skb(skb))
126		return false;
127#endif
128
129	if (ipv6_find_hdr(skb, &off, proto, NULL, NULL) < 0)
130		return false;
131
132	if (!pskb_pull(skb, off))
133		return false;
134
135	skb_postpull_rcsum(skb, skb_network_header(skb), off);
136
137	skb_reset_network_header(skb);
138	skb_reset_transport_header(skb);
139	if (iptunnel_pull_offloads(skb))
140		return false;
141
142	return true;
143}
144
145static void advance_nextseg(struct ipv6_sr_hdr *srh, struct in6_addr *daddr)
146{
147	struct in6_addr *addr;
148
149	srh->segments_left--;
150	addr = srh->segments + srh->segments_left;
151	*daddr = *addr;
152}
153
154static int
155seg6_lookup_any_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
156			u32 tbl_id, bool local_delivery)
157{
158	struct net *net = dev_net(skb->dev);
159	struct ipv6hdr *hdr = ipv6_hdr(skb);
160	int flags = RT6_LOOKUP_F_HAS_SADDR;
161	struct dst_entry *dst = NULL;
162	struct rt6_info *rt;
163	struct flowi6 fl6;
164	int dev_flags = 0;
165
166	fl6.flowi6_iif = skb->dev->ifindex;
167	fl6.daddr = nhaddr ? *nhaddr : hdr->daddr;
168	fl6.saddr = hdr->saddr;
169	fl6.flowlabel = ip6_flowinfo(hdr);
170	fl6.flowi6_mark = skb->mark;
171	fl6.flowi6_proto = hdr->nexthdr;
172
173	if (nhaddr)
174		fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH;
175
176	if (!tbl_id) {
177		dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags);
178	} else {
179		struct fib6_table *table;
180
181		table = fib6_get_table(net, tbl_id);
182		if (!table)
183			goto out;
184
185		rt = ip6_pol_route(net, table, 0, &fl6, skb, flags);
186		dst = &rt->dst;
187	}
188
189	/* we want to discard traffic destined for local packet processing,
190	 * if @local_delivery is set to false.
191	 */
192	if (!local_delivery)
193		dev_flags |= IFF_LOOPBACK;
194
195	if (dst && (dst->dev->flags & dev_flags) && !dst->error) {
196		dst_release(dst);
197		dst = NULL;
198	}
199
200out:
201	if (!dst) {
202		rt = net->ipv6.ip6_blk_hole_entry;
203		dst = &rt->dst;
204		dst_hold(dst);
205	}
206
207	skb_dst_drop(skb);
208	skb_dst_set(skb, dst);
209	return dst->error;
210}
211
212int seg6_lookup_nexthop(struct sk_buff *skb,
213			struct in6_addr *nhaddr, u32 tbl_id)
214{
215	return seg6_lookup_any_nexthop(skb, nhaddr, tbl_id, false);
216}
217
218/* regular endpoint function */
219static int input_action_end(struct sk_buff *skb, struct seg6_local_lwt *slwt)
220{
221	struct ipv6_sr_hdr *srh;
222
223	srh = get_and_validate_srh(skb);
224	if (!srh)
225		goto drop;
226
227	advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
228
229	seg6_lookup_nexthop(skb, NULL, 0);
230
231	return dst_input(skb);
232
233drop:
234	kfree_skb(skb);
235	return -EINVAL;
236}
237
238/* regular endpoint, and forward to specified nexthop */
239static int input_action_end_x(struct sk_buff *skb, struct seg6_local_lwt *slwt)
240{
241	struct ipv6_sr_hdr *srh;
242
243	srh = get_and_validate_srh(skb);
244	if (!srh)
245		goto drop;
246
247	advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
248
249	seg6_lookup_nexthop(skb, &slwt->nh6, 0);
250
251	return dst_input(skb);
252
253drop:
254	kfree_skb(skb);
255	return -EINVAL;
256}
257
258static int input_action_end_t(struct sk_buff *skb, struct seg6_local_lwt *slwt)
259{
260	struct ipv6_sr_hdr *srh;
261
262	srh = get_and_validate_srh(skb);
263	if (!srh)
264		goto drop;
265
266	advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
267
268	seg6_lookup_nexthop(skb, NULL, slwt->table);
269
270	return dst_input(skb);
271
272drop:
273	kfree_skb(skb);
274	return -EINVAL;
275}
276
277/* decapsulate and forward inner L2 frame on specified interface */
278static int input_action_end_dx2(struct sk_buff *skb,
279				struct seg6_local_lwt *slwt)
280{
281	struct net *net = dev_net(skb->dev);
282	struct net_device *odev;
283	struct ethhdr *eth;
284
285	if (!decap_and_validate(skb, IPPROTO_ETHERNET))
286		goto drop;
287
288	if (!pskb_may_pull(skb, ETH_HLEN))
289		goto drop;
290
291	skb_reset_mac_header(skb);
292	eth = (struct ethhdr *)skb->data;
293
294	/* To determine the frame's protocol, we assume it is 802.3. This avoids
295	 * a call to eth_type_trans(), which is not really relevant for our
296	 * use case.
297	 */
298	if (!eth_proto_is_802_3(eth->h_proto))
299		goto drop;
300
301	odev = dev_get_by_index_rcu(net, slwt->oif);
302	if (!odev)
303		goto drop;
304
305	/* As we accept Ethernet frames, make sure the egress device is of
306	 * the correct type.
307	 */
308	if (odev->type != ARPHRD_ETHER)
309		goto drop;
310
311	if (!(odev->flags & IFF_UP) || !netif_carrier_ok(odev))
312		goto drop;
313
314	skb_orphan(skb);
315
316	if (skb_warn_if_lro(skb))
317		goto drop;
318
319	skb_forward_csum(skb);
320
321	if (skb->len - ETH_HLEN > odev->mtu)
322		goto drop;
323
324	skb->dev = odev;
325	skb->protocol = eth->h_proto;
326
327	return dev_queue_xmit(skb);
328
329drop:
330	kfree_skb(skb);
331	return -EINVAL;
332}
333
334/* decapsulate and forward to specified nexthop */
335static int input_action_end_dx6(struct sk_buff *skb,
336				struct seg6_local_lwt *slwt)
337{
338	struct in6_addr *nhaddr = NULL;
339
340	/* this function accepts IPv6 encapsulated packets, with either
341	 * an SRH with SL=0, or no SRH.
342	 */
343
344	if (!decap_and_validate(skb, IPPROTO_IPV6))
345		goto drop;
346
347	if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
348		goto drop;
349
350	/* The inner packet is not associated to any local interface,
351	 * so we do not call netif_rx().
352	 *
353	 * If slwt->nh6 is set to ::, then lookup the nexthop for the
354	 * inner packet's DA. Otherwise, use the specified nexthop.
355	 */
356
357	if (!ipv6_addr_any(&slwt->nh6))
358		nhaddr = &slwt->nh6;
359
360	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
361
362	seg6_lookup_nexthop(skb, nhaddr, 0);
363
364	return dst_input(skb);
365drop:
366	kfree_skb(skb);
367	return -EINVAL;
368}
369
370static int input_action_end_dx4(struct sk_buff *skb,
371				struct seg6_local_lwt *slwt)
372{
373	struct iphdr *iph;
374	__be32 nhaddr;
375	int err;
376
377	if (!decap_and_validate(skb, IPPROTO_IPIP))
378		goto drop;
379
380	if (!pskb_may_pull(skb, sizeof(struct iphdr)))
381		goto drop;
382
383	skb->protocol = htons(ETH_P_IP);
384
385	iph = ip_hdr(skb);
386
387	nhaddr = slwt->nh4.s_addr ?: iph->daddr;
388
389	skb_dst_drop(skb);
390
391	skb_set_transport_header(skb, sizeof(struct iphdr));
392
393	err = ip_route_input(skb, nhaddr, iph->saddr, 0, skb->dev);
394	if (err)
395		goto drop;
396
397	return dst_input(skb);
398
399drop:
400	kfree_skb(skb);
401	return -EINVAL;
402}
403
404static int input_action_end_dt6(struct sk_buff *skb,
405				struct seg6_local_lwt *slwt)
406{
407	if (!decap_and_validate(skb, IPPROTO_IPV6))
408		goto drop;
409
410	if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
411		goto drop;
412
413	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
414
415	seg6_lookup_any_nexthop(skb, NULL, slwt->table, true);
416
417	return dst_input(skb);
418
419drop:
420	kfree_skb(skb);
421	return -EINVAL;
422}
423
424/* push an SRH on top of the current one */
425static int input_action_end_b6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
426{
427	struct ipv6_sr_hdr *srh;
428	int err = -EINVAL;
429
430	srh = get_and_validate_srh(skb);
431	if (!srh)
432		goto drop;
433
434	err = seg6_do_srh_inline(skb, slwt->srh);
435	if (err)
436		goto drop;
437
438	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
439
440	seg6_lookup_nexthop(skb, NULL, 0);
441
442	return dst_input(skb);
443
444drop:
445	kfree_skb(skb);
446	return err;
447}
448
449/* encapsulate within an outer IPv6 header and a specified SRH */
450static int input_action_end_b6_encap(struct sk_buff *skb,
451				     struct seg6_local_lwt *slwt)
452{
453	struct ipv6_sr_hdr *srh;
454	int err = -EINVAL;
455
456	srh = get_and_validate_srh(skb);
457	if (!srh)
458		goto drop;
459
460	advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
461
462	skb_reset_inner_headers(skb);
463	skb->encapsulation = 1;
464
465	err = seg6_do_srh_encap(skb, slwt->srh, IPPROTO_IPV6);
466	if (err)
467		goto drop;
468
469	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
470
471	seg6_lookup_nexthop(skb, NULL, 0);
472
473	return dst_input(skb);
474
475drop:
476	kfree_skb(skb);
477	return err;
478}
479
480DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states);
481
482bool seg6_bpf_has_valid_srh(struct sk_buff *skb)
483{
484	struct seg6_bpf_srh_state *srh_state =
485		this_cpu_ptr(&seg6_bpf_srh_states);
486	struct ipv6_sr_hdr *srh = srh_state->srh;
487
488	if (unlikely(srh == NULL))
489		return false;
490
491	if (unlikely(!srh_state->valid)) {
492		if ((srh_state->hdrlen & 7) != 0)
493			return false;
494
495		srh->hdrlen = (u8)(srh_state->hdrlen >> 3);
496		if (!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3, true))
497			return false;
498
499		srh_state->valid = true;
500	}
501
502	return true;
503}
504
505static int input_action_end_bpf(struct sk_buff *skb,
506				struct seg6_local_lwt *slwt)
507{
508	struct seg6_bpf_srh_state *srh_state =
509		this_cpu_ptr(&seg6_bpf_srh_states);
510	struct ipv6_sr_hdr *srh;
511	int ret;
512
513	srh = get_and_validate_srh(skb);
514	if (!srh) {
515		kfree_skb(skb);
516		return -EINVAL;
517	}
518	advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
519
520	/* preempt_disable is needed to protect the per-CPU buffer srh_state,
521	 * which is also accessed by the bpf_lwt_seg6_* helpers
522	 */
523	preempt_disable();
524	srh_state->srh = srh;
525	srh_state->hdrlen = srh->hdrlen << 3;
526	srh_state->valid = true;
527
528	rcu_read_lock();
529	bpf_compute_data_pointers(skb);
530	ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb);
531	rcu_read_unlock();
532
533	switch (ret) {
534	case BPF_OK:
535	case BPF_REDIRECT:
536		break;
537	case BPF_DROP:
538		goto drop;
539	default:
540		pr_warn_once("bpf-seg6local: Illegal return value %u\n", ret);
541		goto drop;
542	}
543
544	if (srh_state->srh && !seg6_bpf_has_valid_srh(skb))
545		goto drop;
546
547	preempt_enable();
548	if (ret != BPF_REDIRECT)
549		seg6_lookup_nexthop(skb, NULL, 0);
550
551	return dst_input(skb);
552
553drop:
554	preempt_enable();
555	kfree_skb(skb);
556	return -EINVAL;
557}
558
559static struct seg6_action_desc seg6_action_table[] = {
560	{
561		.action		= SEG6_LOCAL_ACTION_END,
562		.attrs		= 0,
563		.input		= input_action_end,
564	},
565	{
566		.action		= SEG6_LOCAL_ACTION_END_X,
567		.attrs		= (1 << SEG6_LOCAL_NH6),
568		.input		= input_action_end_x,
569	},
570	{
571		.action		= SEG6_LOCAL_ACTION_END_T,
572		.attrs		= (1 << SEG6_LOCAL_TABLE),
573		.input		= input_action_end_t,
574	},
575	{
576		.action		= SEG6_LOCAL_ACTION_END_DX2,
577		.attrs		= (1 << SEG6_LOCAL_OIF),
578		.input		= input_action_end_dx2,
579	},
580	{
581		.action		= SEG6_LOCAL_ACTION_END_DX6,
582		.attrs		= (1 << SEG6_LOCAL_NH6),
583		.input		= input_action_end_dx6,
584	},
585	{
586		.action		= SEG6_LOCAL_ACTION_END_DX4,
587		.attrs		= (1 << SEG6_LOCAL_NH4),
588		.input		= input_action_end_dx4,
589	},
590	{
591		.action		= SEG6_LOCAL_ACTION_END_DT6,
592		.attrs		= (1 << SEG6_LOCAL_TABLE),
593		.input		= input_action_end_dt6,
594	},
595	{
596		.action		= SEG6_LOCAL_ACTION_END_B6,
597		.attrs		= (1 << SEG6_LOCAL_SRH),
598		.input		= input_action_end_b6,
599	},
600	{
601		.action		= SEG6_LOCAL_ACTION_END_B6_ENCAP,
602		.attrs		= (1 << SEG6_LOCAL_SRH),
603		.input		= input_action_end_b6_encap,
604		.static_headroom	= sizeof(struct ipv6hdr),
605	},
606	{
607		.action		= SEG6_LOCAL_ACTION_END_BPF,
608		.attrs		= (1 << SEG6_LOCAL_BPF),
609		.input		= input_action_end_bpf,
610	},
611
612};
613
614static struct seg6_action_desc *__get_action_desc(int action)
615{
616	struct seg6_action_desc *desc;
617	int i, count;
618
619	count = ARRAY_SIZE(seg6_action_table);
620	for (i = 0; i < count; i++) {
621		desc = &seg6_action_table[i];
622		if (desc->action == action)
623			return desc;
624	}
625
626	return NULL;
627}
628
629static int seg6_local_input(struct sk_buff *skb)
630{
631	struct dst_entry *orig_dst = skb_dst(skb);
632	struct seg6_action_desc *desc;
633	struct seg6_local_lwt *slwt;
634
635	if (skb->protocol != htons(ETH_P_IPV6)) {
636		kfree_skb(skb);
637		return -EINVAL;
638	}
639
640	slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
641	desc = slwt->desc;
642
643	return desc->input(skb, slwt);
644}
645
646static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
647	[SEG6_LOCAL_ACTION]	= { .type = NLA_U32 },
648	[SEG6_LOCAL_SRH]	= { .type = NLA_BINARY },
649	[SEG6_LOCAL_TABLE]	= { .type = NLA_U32 },
650	[SEG6_LOCAL_NH4]	= { .type = NLA_BINARY,
651				    .len = sizeof(struct in_addr) },
652	[SEG6_LOCAL_NH6]	= { .type = NLA_BINARY,
653				    .len = sizeof(struct in6_addr) },
654	[SEG6_LOCAL_IIF]	= { .type = NLA_U32 },
655	[SEG6_LOCAL_OIF]	= { .type = NLA_U32 },
656	[SEG6_LOCAL_BPF]	= { .type = NLA_NESTED },
657};
658
659static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt)
660{
661	struct ipv6_sr_hdr *srh;
662	int len;
663
664	srh = nla_data(attrs[SEG6_LOCAL_SRH]);
665	len = nla_len(attrs[SEG6_LOCAL_SRH]);
666
667	/* SRH must contain at least one segment */
668	if (len < sizeof(*srh) + sizeof(struct in6_addr))
669		return -EINVAL;
670
671	if (!seg6_validate_srh(srh, len, false))
672		return -EINVAL;
673
674	slwt->srh = kmemdup(srh, len, GFP_KERNEL);
675	if (!slwt->srh)
676		return -ENOMEM;
677
678	slwt->headroom += len;
679
680	return 0;
681}
682
683static int put_nla_srh(struct sk_buff *skb, struct seg6_local_lwt *slwt)
684{
685	struct ipv6_sr_hdr *srh;
686	struct nlattr *nla;
687	int len;
688
689	srh = slwt->srh;
690	len = (srh->hdrlen + 1) << 3;
691
692	nla = nla_reserve(skb, SEG6_LOCAL_SRH, len);
693	if (!nla)
694		return -EMSGSIZE;
695
696	memcpy(nla_data(nla), srh, len);
697
698	return 0;
699}
700
701static int cmp_nla_srh(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
702{
703	int len = (a->srh->hdrlen + 1) << 3;
704
705	if (len != ((b->srh->hdrlen + 1) << 3))
706		return 1;
707
708	return memcmp(a->srh, b->srh, len);
709}
710
711static int parse_nla_table(struct nlattr **attrs, struct seg6_local_lwt *slwt)
712{
713	slwt->table = nla_get_u32(attrs[SEG6_LOCAL_TABLE]);
714
715	return 0;
716}
717
718static int put_nla_table(struct sk_buff *skb, struct seg6_local_lwt *slwt)
719{
720	if (nla_put_u32(skb, SEG6_LOCAL_TABLE, slwt->table))
721		return -EMSGSIZE;
722
723	return 0;
724}
725
726static int cmp_nla_table(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
727{
728	if (a->table != b->table)
729		return 1;
730
731	return 0;
732}
733
734static int parse_nla_nh4(struct nlattr **attrs, struct seg6_local_lwt *slwt)
735{
736	memcpy(&slwt->nh4, nla_data(attrs[SEG6_LOCAL_NH4]),
737	       sizeof(struct in_addr));
738
739	return 0;
740}
741
742static int put_nla_nh4(struct sk_buff *skb, struct seg6_local_lwt *slwt)
743{
744	struct nlattr *nla;
745
746	nla = nla_reserve(skb, SEG6_LOCAL_NH4, sizeof(struct in_addr));
747	if (!nla)
748		return -EMSGSIZE;
749
750	memcpy(nla_data(nla), &slwt->nh4, sizeof(struct in_addr));
751
752	return 0;
753}
754
755static int cmp_nla_nh4(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
756{
757	return memcmp(&a->nh4, &b->nh4, sizeof(struct in_addr));
758}
759
760static int parse_nla_nh6(struct nlattr **attrs, struct seg6_local_lwt *slwt)
761{
762	memcpy(&slwt->nh6, nla_data(attrs[SEG6_LOCAL_NH6]),
763	       sizeof(struct in6_addr));
764
765	return 0;
766}
767
768static int put_nla_nh6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
769{
770	struct nlattr *nla;
771
772	nla = nla_reserve(skb, SEG6_LOCAL_NH6, sizeof(struct in6_addr));
773	if (!nla)
774		return -EMSGSIZE;
775
776	memcpy(nla_data(nla), &slwt->nh6, sizeof(struct in6_addr));
777
778	return 0;
779}
780
781static int cmp_nla_nh6(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
782{
783	return memcmp(&a->nh6, &b->nh6, sizeof(struct in6_addr));
784}
785
786static int parse_nla_iif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
787{
788	slwt->iif = nla_get_u32(attrs[SEG6_LOCAL_IIF]);
789
790	return 0;
791}
792
793static int put_nla_iif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
794{
795	if (nla_put_u32(skb, SEG6_LOCAL_IIF, slwt->iif))
796		return -EMSGSIZE;
797
798	return 0;
799}
800
801static int cmp_nla_iif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
802{
803	if (a->iif != b->iif)
804		return 1;
805
806	return 0;
807}
808
809static int parse_nla_oif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
810{
811	slwt->oif = nla_get_u32(attrs[SEG6_LOCAL_OIF]);
812
813	return 0;
814}
815
816static int put_nla_oif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
817{
818	if (nla_put_u32(skb, SEG6_LOCAL_OIF, slwt->oif))
819		return -EMSGSIZE;
820
821	return 0;
822}
823
824static int cmp_nla_oif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
825{
826	if (a->oif != b->oif)
827		return 1;
828
829	return 0;
830}
831
832#define MAX_PROG_NAME 256
833static const struct nla_policy bpf_prog_policy[SEG6_LOCAL_BPF_PROG_MAX + 1] = {
834	[SEG6_LOCAL_BPF_PROG]	   = { .type = NLA_U32, },
835	[SEG6_LOCAL_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
836				       .len = MAX_PROG_NAME },
837};
838
839static int parse_nla_bpf(struct nlattr **attrs, struct seg6_local_lwt *slwt)
840{
841	struct nlattr *tb[SEG6_LOCAL_BPF_PROG_MAX + 1];
842	struct bpf_prog *p;
843	int ret;
844	u32 fd;
845
846	ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_BPF_PROG_MAX,
847					  attrs[SEG6_LOCAL_BPF],
848					  bpf_prog_policy, NULL);
849	if (ret < 0)
850		return ret;
851
852	if (!tb[SEG6_LOCAL_BPF_PROG] || !tb[SEG6_LOCAL_BPF_PROG_NAME])
853		return -EINVAL;
854
855	slwt->bpf.name = nla_memdup(tb[SEG6_LOCAL_BPF_PROG_NAME], GFP_KERNEL);
856	if (!slwt->bpf.name)
857		return -ENOMEM;
858
859	fd = nla_get_u32(tb[SEG6_LOCAL_BPF_PROG]);
860	p = bpf_prog_get_type(fd, BPF_PROG_TYPE_LWT_SEG6LOCAL);
861	if (IS_ERR(p)) {
862		kfree(slwt->bpf.name);
863		return PTR_ERR(p);
864	}
865
866	slwt->bpf.prog = p;
867	return 0;
868}
869
870static int put_nla_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt)
871{
872	struct nlattr *nest;
873
874	if (!slwt->bpf.prog)
875		return 0;
876
877	nest = nla_nest_start_noflag(skb, SEG6_LOCAL_BPF);
878	if (!nest)
879		return -EMSGSIZE;
880
881	if (nla_put_u32(skb, SEG6_LOCAL_BPF_PROG, slwt->bpf.prog->aux->id))
882		return -EMSGSIZE;
883
884	if (slwt->bpf.name &&
885	    nla_put_string(skb, SEG6_LOCAL_BPF_PROG_NAME, slwt->bpf.name))
886		return -EMSGSIZE;
887
888	return nla_nest_end(skb, nest);
889}
890
891static int cmp_nla_bpf(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
892{
893	if (!a->bpf.name && !b->bpf.name)
894		return 0;
895
896	if (!a->bpf.name || !b->bpf.name)
897		return 1;
898
899	return strcmp(a->bpf.name, b->bpf.name);
900}
901
/* Netlink handlers for a single SEG6_LOCAL_* attribute. */
struct seg6_action_param {
	/* decode attrs[type] into slwt; negative errno on failure */
	int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt);
	/* encode the stored value into skb; negative errno on failure */
	int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
	/* 0 when a and b hold the same value */
	int (*cmp)(struct seg6_local_lwt *a, struct seg6_local_lwt *b);
};
907
908static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = {
909	[SEG6_LOCAL_SRH]	= { .parse = parse_nla_srh,
910				    .put = put_nla_srh,
911				    .cmp = cmp_nla_srh },
912
913	[SEG6_LOCAL_TABLE]	= { .parse = parse_nla_table,
914				    .put = put_nla_table,
915				    .cmp = cmp_nla_table },
916
917	[SEG6_LOCAL_NH4]	= { .parse = parse_nla_nh4,
918				    .put = put_nla_nh4,
919				    .cmp = cmp_nla_nh4 },
920
921	[SEG6_LOCAL_NH6]	= { .parse = parse_nla_nh6,
922				    .put = put_nla_nh6,
923				    .cmp = cmp_nla_nh6 },
924
925	[SEG6_LOCAL_IIF]	= { .parse = parse_nla_iif,
926				    .put = put_nla_iif,
927				    .cmp = cmp_nla_iif },
928
929	[SEG6_LOCAL_OIF]	= { .parse = parse_nla_oif,
930				    .put = put_nla_oif,
931				    .cmp = cmp_nla_oif },
932
933	[SEG6_LOCAL_BPF]	= { .parse = parse_nla_bpf,
934				    .put = put_nla_bpf,
935				    .cmp = cmp_nla_bpf },
936
937};
938
939static int parse_nla_action(struct nlattr **attrs, struct seg6_local_lwt *slwt)
940{
941	struct seg6_action_param *param;
942	struct seg6_action_desc *desc;
943	int i, err;
944
945	desc = __get_action_desc(slwt->action);
946	if (!desc)
947		return -EINVAL;
948
949	if (!desc->input)
950		return -EOPNOTSUPP;
951
952	slwt->desc = desc;
953	slwt->headroom += desc->static_headroom;
954
955	for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
956		if (desc->attrs & (1 << i)) {
957			if (!attrs[i])
958				return -EINVAL;
959
960			param = &seg6_action_params[i];
961
962			err = param->parse(attrs, slwt);
963			if (err < 0)
964				return err;
965		}
966	}
967
968	return 0;
969}
970
971static int seg6_local_build_state(struct net *net, struct nlattr *nla,
972				  unsigned int family, const void *cfg,
973				  struct lwtunnel_state **ts,
974				  struct netlink_ext_ack *extack)
975{
976	struct nlattr *tb[SEG6_LOCAL_MAX + 1];
977	struct lwtunnel_state *newts;
978	struct seg6_local_lwt *slwt;
979	int err;
980
981	if (family != AF_INET6)
982		return -EINVAL;
983
984	err = nla_parse_nested_deprecated(tb, SEG6_LOCAL_MAX, nla,
985					  seg6_local_policy, extack);
986
987	if (err < 0)
988		return err;
989
990	if (!tb[SEG6_LOCAL_ACTION])
991		return -EINVAL;
992
993	newts = lwtunnel_state_alloc(sizeof(*slwt));
994	if (!newts)
995		return -ENOMEM;
996
997	slwt = seg6_local_lwtunnel(newts);
998	slwt->action = nla_get_u32(tb[SEG6_LOCAL_ACTION]);
999
1000	err = parse_nla_action(tb, slwt);
1001	if (err < 0)
1002		goto out_free;
1003
1004	newts->type = LWTUNNEL_ENCAP_SEG6_LOCAL;
1005	newts->flags = LWTUNNEL_STATE_INPUT_REDIRECT;
1006	newts->headroom = slwt->headroom;
1007
1008	*ts = newts;
1009
1010	return 0;
1011
1012out_free:
1013	kfree(slwt->srh);
1014	kfree(newts);
1015	return err;
1016}
1017
1018static void seg6_local_destroy_state(struct lwtunnel_state *lwt)
1019{
1020	struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1021
1022	kfree(slwt->srh);
1023
1024	if (slwt->desc->attrs & (1 << SEG6_LOCAL_BPF)) {
1025		kfree(slwt->bpf.name);
1026		bpf_prog_put(slwt->bpf.prog);
1027	}
1028
1029	return;
1030}
1031
1032static int seg6_local_fill_encap(struct sk_buff *skb,
1033				 struct lwtunnel_state *lwt)
1034{
1035	struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1036	struct seg6_action_param *param;
1037	int i, err;
1038
1039	if (nla_put_u32(skb, SEG6_LOCAL_ACTION, slwt->action))
1040		return -EMSGSIZE;
1041
1042	for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1043		if (slwt->desc->attrs & (1 << i)) {
1044			param = &seg6_action_params[i];
1045			err = param->put(skb, slwt);
1046			if (err < 0)
1047				return err;
1048		}
1049	}
1050
1051	return 0;
1052}
1053
1054static int seg6_local_get_encap_size(struct lwtunnel_state *lwt)
1055{
1056	struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1057	unsigned long attrs;
1058	int nlsize;
1059
1060	nlsize = nla_total_size(4); /* action */
1061
1062	attrs = slwt->desc->attrs;
1063
1064	if (attrs & (1 << SEG6_LOCAL_SRH))
1065		nlsize += nla_total_size((slwt->srh->hdrlen + 1) << 3);
1066
1067	if (attrs & (1 << SEG6_LOCAL_TABLE))
1068		nlsize += nla_total_size(4);
1069
1070	if (attrs & (1 << SEG6_LOCAL_NH4))
1071		nlsize += nla_total_size(4);
1072
1073	if (attrs & (1 << SEG6_LOCAL_NH6))
1074		nlsize += nla_total_size(16);
1075
1076	if (attrs & (1 << SEG6_LOCAL_IIF))
1077		nlsize += nla_total_size(4);
1078
1079	if (attrs & (1 << SEG6_LOCAL_OIF))
1080		nlsize += nla_total_size(4);
1081
1082	if (attrs & (1 << SEG6_LOCAL_BPF))
1083		nlsize += nla_total_size(sizeof(struct nlattr)) +
1084		       nla_total_size(MAX_PROG_NAME) +
1085		       nla_total_size(4);
1086
1087	return nlsize;
1088}
1089
1090static int seg6_local_cmp_encap(struct lwtunnel_state *a,
1091				struct lwtunnel_state *b)
1092{
1093	struct seg6_local_lwt *slwt_a, *slwt_b;
1094	struct seg6_action_param *param;
1095	int i;
1096
1097	slwt_a = seg6_local_lwtunnel(a);
1098	slwt_b = seg6_local_lwtunnel(b);
1099
1100	if (slwt_a->action != slwt_b->action)
1101		return 1;
1102
1103	if (slwt_a->desc->attrs != slwt_b->desc->attrs)
1104		return 1;
1105
1106	for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1107		if (slwt_a->desc->attrs & (1 << i)) {
1108			param = &seg6_action_params[i];
1109			if (param->cmp(slwt_a, slwt_b))
1110				return 1;
1111		}
1112	}
1113
1114	return 0;
1115}
1116
1117static const struct lwtunnel_encap_ops seg6_local_ops = {
1118	.build_state	= seg6_local_build_state,
1119	.destroy_state	= seg6_local_destroy_state,
1120	.input		= seg6_local_input,
1121	.fill_encap	= seg6_local_fill_encap,
1122	.get_encap_size	= seg6_local_get_encap_size,
1123	.cmp_encap	= seg6_local_cmp_encap,
1124	.owner		= THIS_MODULE,
1125};
1126
1127int __init seg6_local_init(void)
1128{
1129	return lwtunnel_encap_add_ops(&seg6_local_ops,
1130				      LWTUNNEL_ENCAP_SEG6_LOCAL);
1131}
1132
1133void seg6_local_exit(void)
1134{
1135	lwtunnel_encap_del_ops(&seg6_local_ops, LWTUNNEL_ENCAP_SEG6_LOCAL);
1136}
1137