1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 *  SR-IPv6 implementation
4 *
5 *  Author:
6 *  David Lebrun <david.lebrun@uclouvain.be>
7 */
8
9#include <linux/types.h>
10#include <linux/skbuff.h>
11#include <linux/net.h>
12#include <linux/module.h>
13#include <net/ip.h>
14#include <net/ip_tunnels.h>
15#include <net/lwtunnel.h>
16#include <net/netevent.h>
17#include <net/netns/generic.h>
18#include <net/ip6_fib.h>
19#include <net/route.h>
20#include <net/seg6.h>
21#include <linux/seg6.h>
22#include <linux/seg6_iptunnel.h>
23#include <net/addrconf.h>
24#include <net/ip6_route.h>
25#include <net/dst_cache.h>
26#ifdef CONFIG_IPV6_SEG6_HMAC
27#include <net/seg6_hmac.h>
28#endif
29
30static size_t seg6_lwt_headroom(struct seg6_iptunnel_encap *tuninfo)
31{
32	int head = 0;
33
34	switch (tuninfo->mode) {
35	case SEG6_IPTUN_MODE_INLINE:
36		break;
37	case SEG6_IPTUN_MODE_ENCAP:
38		head = sizeof(struct ipv6hdr);
39		break;
40	case SEG6_IPTUN_MODE_L2ENCAP:
41		return 0;
42	}
43
44	return ((tuninfo->srh->hdrlen + 1) << 3) + head;
45}
46
47struct seg6_lwt {
48	struct dst_cache cache;
49	struct seg6_iptunnel_encap tuninfo[];
50};
51
52static inline struct seg6_lwt *seg6_lwt_lwtunnel(struct lwtunnel_state *lwt)
53{
54	return (struct seg6_lwt *)lwt->data;
55}
56
57static inline struct seg6_iptunnel_encap *
58seg6_encap_lwtunnel(struct lwtunnel_state *lwt)
59{
60	return seg6_lwt_lwtunnel(lwt)->tuninfo;
61}
62
63static const struct nla_policy seg6_iptunnel_policy[SEG6_IPTUNNEL_MAX + 1] = {
64	[SEG6_IPTUNNEL_SRH]	= { .type = NLA_BINARY },
65};
66
67static int nla_put_srh(struct sk_buff *skb, int attrtype,
68		       struct seg6_iptunnel_encap *tuninfo)
69{
70	struct seg6_iptunnel_encap *data;
71	struct nlattr *nla;
72	int len;
73
74	len = SEG6_IPTUN_ENCAP_SIZE(tuninfo);
75
76	nla = nla_reserve(skb, attrtype, len);
77	if (!nla)
78		return -EMSGSIZE;
79
80	data = nla_data(nla);
81	memcpy(data, tuninfo, len);
82
83	return 0;
84}
85
86static void set_tun_src(struct net *net, struct net_device *dev,
87			struct in6_addr *daddr, struct in6_addr *saddr)
88{
89	struct seg6_pernet_data *sdata = seg6_pernet(net);
90	struct in6_addr *tun_src;
91
92	rcu_read_lock();
93
94	tun_src = rcu_dereference(sdata->tun_src);
95
96	if (!ipv6_addr_any(tun_src)) {
97		memcpy(saddr, tun_src, sizeof(struct in6_addr));
98	} else {
99		ipv6_dev_get_saddr(net, dev, daddr, IPV6_PREFER_SRC_PUBLIC,
100				   saddr);
101	}
102
103	rcu_read_unlock();
104}
105
106/* Compute flowlabel for outer IPv6 header */
107static __be32 seg6_make_flowlabel(struct net *net, struct sk_buff *skb,
108				  struct ipv6hdr *inner_hdr)
109{
110	int do_flowlabel = net->ipv6.sysctl.seg6_flowlabel;
111	__be32 flowlabel = 0;
112	u32 hash;
113
114	if (do_flowlabel > 0) {
115		hash = skb_get_hash(skb);
116		hash = rol32(hash, 16);
117		flowlabel = (__force __be32)hash & IPV6_FLOWLABEL_MASK;
118	} else if (!do_flowlabel && skb->protocol == htons(ETH_P_IPV6)) {
119		flowlabel = ip6_flowlabel(inner_hdr);
120	}
121	return flowlabel;
122}
123
/* encapsulate a packet within an outer IPv6 header with a given SRH
 *
 * Pushes an outer IPv6 header followed by a copy of @osrh in front of the
 * current network header. The SRH's next-header field is set to @proto
 * (seg6_do_srh() passes IPPROTO_IPV6, IPPROTO_IPIP or IPPROTO_ETHERNET).
 * The outer destination is the SRH entry at first_segment, and the outer
 * source is chosen by set_tun_src(). If the HMAC option is compiled in and
 * the SRH requests it, the HMAC is pushed as well.
 *
 * Returns 0 on success or a negative errno. The skb is never freed here;
 * on error the caller is responsible for it.
 */
int seg6_do_srh_encap(struct sk_buff *skb, struct ipv6_sr_hdr *osrh, int proto)
{
	struct dst_entry *dst = skb_dst(skb);
	struct net *net = dev_net(dst->dev);
	struct ipv6hdr *hdr, *inner_hdr;
	struct ipv6_sr_hdr *isrh;
	int hdrlen, tot_len, err;
	__be32 flowlabel;

	/* SRH size in bytes: hdrlen counts 8-octet units after the first 8 */
	hdrlen = (osrh->hdrlen + 1) << 3;
	tot_len = hdrlen + sizeof(*hdr);

	/* reserve headroom for outer IPv6 header + SRH, plus the MAC header
	 * that skb_mac_header_rebuild() moves below
	 */
	err = skb_cow_head(skb, tot_len + skb->mac_len);
	if (unlikely(err))
		return err;

	/* read the inner header before the network header is moved; the
	 * headroom ensured above means the push below does not reallocate,
	 * so inner_hdr stays valid afterwards
	 */
	inner_hdr = ipv6_hdr(skb);
	flowlabel = seg6_make_flowlabel(net, skb, inner_hdr);

	skb_push(skb, tot_len);
	skb_reset_network_header(skb);
	skb_mac_header_rebuild(skb);
	hdr = ipv6_hdr(skb);

	/* inherit tc, flowlabel and hlim
	 * hlim will be decremented in ip6_forward() afterwards and
	 * decapsulation will overwrite inner hlim with outer hlim
	 */

	if (skb->protocol == htons(ETH_P_IPV6)) {
		ip6_flow_hdr(hdr, ip6_tclass(ip6_flowinfo(inner_hdr)),
			     flowlabel);
		hdr->hop_limit = inner_hdr->hop_limit;
	} else {
		/* non-IPv6 payload: no inner tclass/hlim to inherit */
		ip6_flow_hdr(hdr, 0, flowlabel);
		hdr->hop_limit = ip6_dst_hoplimit(skb_dst(skb));

		memset(IP6CB(skb), 0, sizeof(*IP6CB(skb)));

		/* the control block has been erased, so we have to set the
		 * iif once again.
		 * We read the receiving interface index directly from the
		 * skb->skb_iif as it is done in the IPv4 receiving path (i.e.:
		 * ip_rcv_core(...)).
		 */
		IP6CB(skb)->iif = skb->skb_iif;
	}

	hdr->nexthdr = NEXTHDR_ROUTING;

	/* the SRH sits immediately after the outer IPv6 header */
	isrh = (void *)hdr + sizeof(*hdr);
	memcpy(isrh, osrh, hdrlen);

	isrh->nexthdr = proto;

	/* first hop of the segment list becomes the outer destination */
	hdr->daddr = isrh->segments[isrh->first_segment];
	set_tun_src(net, dst->dev, &hdr->daddr, &hdr->saddr);

#ifdef CONFIG_IPV6_SEG6_HMAC
	/* the HMAC takes the outer source address, so it must follow
	 * set_tun_src()
	 */
	if (sr_has_hmac(isrh)) {
		err = seg6_push_hmac(net, &hdr->saddr, isrh);
		if (unlikely(err))
			return err;
	}
#endif

	hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr));

	/* account for the pushed bytes in skb->csum */
	skb_postpush_rcsum(skb, hdr, tot_len);

	return 0;
}
EXPORT_SYMBOL_GPL(seg6_do_srh_encap);
198
/* insert an SRH within an IPv6 packet, just after the IPv6 header
 *
 * The existing IPv6 header is moved up by the SRH length, and a copy of
 * @osrh is placed between it and the original payload. The SRH inherits
 * the header's next-header value. The original destination is stored in
 * segments[0], and the header's destination becomes
 * segments[first_segment].
 *
 * Returns 0 on success or a negative errno. The skb is never freed here.
 */
int seg6_do_srh_inline(struct sk_buff *skb, struct ipv6_sr_hdr *osrh)
{
	struct ipv6hdr *hdr, *oldhdr;
	struct ipv6_sr_hdr *isrh;
	int hdrlen, err;

	/* SRH size in bytes: hdrlen counts 8-octet units after the first 8 */
	hdrlen = (osrh->hdrlen + 1) << 3;

	err = skb_cow_head(skb, hdrlen + skb->mac_len);
	if (unlikely(err))
		return err;

	oldhdr = ipv6_hdr(skb);

	/* drop the IPv6 header from the checksummed area, then push it back
	 * together with room for the SRH
	 */
	skb_pull(skb, sizeof(struct ipv6hdr));
	skb_postpull_rcsum(skb, skb_network_header(skb),
			   sizeof(struct ipv6hdr));

	skb_push(skb, sizeof(struct ipv6hdr) + hdrlen);
	skb_reset_network_header(skb);
	skb_mac_header_rebuild(skb);

	hdr = ipv6_hdr(skb);

	/* new location is hdrlen bytes before the old one; the ranges may
	 * overlap, hence memmove
	 */
	memmove(hdr, oldhdr, sizeof(*hdr));

	isrh = (void *)hdr + sizeof(*hdr);
	memcpy(isrh, osrh, hdrlen);

	isrh->nexthdr = hdr->nexthdr;
	hdr->nexthdr = NEXTHDR_ROUTING;

	/* original destination becomes the final segment */
	isrh->segments[0] = hdr->daddr;
	hdr->daddr = isrh->segments[isrh->first_segment];

#ifdef CONFIG_IPV6_SEG6_HMAC
	if (sr_has_hmac(isrh)) {
		struct net *net = dev_net(skb_dst(skb)->dev);

		err = seg6_push_hmac(net, &hdr->saddr, isrh);
		if (unlikely(err))
			return err;
	}
#endif

	hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr));

	/* re-add the IPv6 header and the new SRH to skb->csum */
	skb_postpush_rcsum(skb, hdr, sizeof(struct ipv6hdr) + hdrlen);

	return 0;
}
EXPORT_SYMBOL_GPL(seg6_do_srh_inline);
252
253static int seg6_do_srh(struct sk_buff *skb)
254{
255	struct dst_entry *dst = skb_dst(skb);
256	struct seg6_iptunnel_encap *tinfo;
257	int proto, err = 0;
258
259	tinfo = seg6_encap_lwtunnel(dst->lwtstate);
260
261	switch (tinfo->mode) {
262	case SEG6_IPTUN_MODE_INLINE:
263		if (skb->protocol != htons(ETH_P_IPV6))
264			return -EINVAL;
265
266		err = seg6_do_srh_inline(skb, tinfo->srh);
267		if (err)
268			return err;
269		break;
270	case SEG6_IPTUN_MODE_ENCAP:
271		err = iptunnel_handle_offloads(skb, SKB_GSO_IPXIP6);
272		if (err)
273			return err;
274
275		if (skb->protocol == htons(ETH_P_IPV6))
276			proto = IPPROTO_IPV6;
277		else if (skb->protocol == htons(ETH_P_IP))
278			proto = IPPROTO_IPIP;
279		else
280			return -EINVAL;
281
282		err = seg6_do_srh_encap(skb, tinfo->srh, proto);
283		if (err)
284			return err;
285
286		skb_set_inner_transport_header(skb, skb_transport_offset(skb));
287		skb_set_inner_protocol(skb, skb->protocol);
288		skb->protocol = htons(ETH_P_IPV6);
289		break;
290	case SEG6_IPTUN_MODE_L2ENCAP:
291		if (!skb_mac_header_was_set(skb))
292			return -EINVAL;
293
294		if (pskb_expand_head(skb, skb->mac_len, 0, GFP_ATOMIC) < 0)
295			return -ENOMEM;
296
297		skb_mac_header_rebuild(skb);
298		skb_push(skb, skb->mac_len);
299
300		err = seg6_do_srh_encap(skb, tinfo->srh, IPPROTO_ETHERNET);
301		if (err)
302			return err;
303
304		skb->protocol = htons(ETH_P_IPV6);
305		break;
306	}
307
308	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
309
310	return 0;
311}
312
313static int seg6_input(struct sk_buff *skb)
314{
315	struct dst_entry *orig_dst = skb_dst(skb);
316	struct dst_entry *dst = NULL;
317	struct seg6_lwt *slwt;
318	int err;
319
320	err = seg6_do_srh(skb);
321	if (unlikely(err)) {
322		kfree_skb(skb);
323		return err;
324	}
325
326	slwt = seg6_lwt_lwtunnel(orig_dst->lwtstate);
327
328	preempt_disable();
329	dst = dst_cache_get(&slwt->cache);
330	preempt_enable();
331
332	skb_dst_drop(skb);
333
334	if (!dst) {
335		ip6_route_input(skb);
336		dst = skb_dst(skb);
337		if (!dst->error) {
338			preempt_disable();
339			dst_cache_set_ip6(&slwt->cache, dst,
340					  &ipv6_hdr(skb)->saddr);
341			preempt_enable();
342		}
343	} else {
344		skb_dst_set(skb, dst);
345	}
346
347	err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev));
348	if (unlikely(err))
349		return err;
350
351	return dst_input(skb);
352}
353
354static int seg6_output(struct net *net, struct sock *sk, struct sk_buff *skb)
355{
356	struct dst_entry *orig_dst = skb_dst(skb);
357	struct dst_entry *dst = NULL;
358	struct seg6_lwt *slwt;
359	int err = -EINVAL;
360
361	err = seg6_do_srh(skb);
362	if (unlikely(err))
363		goto drop;
364
365	slwt = seg6_lwt_lwtunnel(orig_dst->lwtstate);
366
367	preempt_disable();
368	dst = dst_cache_get(&slwt->cache);
369	preempt_enable();
370
371	if (unlikely(!dst)) {
372		struct ipv6hdr *hdr = ipv6_hdr(skb);
373		struct flowi6 fl6;
374
375		memset(&fl6, 0, sizeof(fl6));
376		fl6.daddr = hdr->daddr;
377		fl6.saddr = hdr->saddr;
378		fl6.flowlabel = ip6_flowinfo(hdr);
379		fl6.flowi6_mark = skb->mark;
380		fl6.flowi6_proto = hdr->nexthdr;
381
382		dst = ip6_route_output(net, NULL, &fl6);
383		if (dst->error) {
384			err = dst->error;
385			dst_release(dst);
386			goto drop;
387		}
388
389		preempt_disable();
390		dst_cache_set_ip6(&slwt->cache, dst, &fl6.saddr);
391		preempt_enable();
392	}
393
394	skb_dst_drop(skb);
395	skb_dst_set(skb, dst);
396
397	err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev));
398	if (unlikely(err))
399		goto drop;
400
401	return dst_output(net, sk, skb);
402drop:
403	kfree_skb(skb);
404	return err;
405}
406
407static int seg6_build_state(struct net *net, struct nlattr *nla,
408			    unsigned int family, const void *cfg,
409			    struct lwtunnel_state **ts,
410			    struct netlink_ext_ack *extack)
411{
412	struct nlattr *tb[SEG6_IPTUNNEL_MAX + 1];
413	struct seg6_iptunnel_encap *tuninfo;
414	struct lwtunnel_state *newts;
415	int tuninfo_len, min_size;
416	struct seg6_lwt *slwt;
417	int err;
418
419	if (family != AF_INET && family != AF_INET6)
420		return -EINVAL;
421
422	err = nla_parse_nested_deprecated(tb, SEG6_IPTUNNEL_MAX, nla,
423					  seg6_iptunnel_policy, extack);
424
425	if (err < 0)
426		return err;
427
428	if (!tb[SEG6_IPTUNNEL_SRH])
429		return -EINVAL;
430
431	tuninfo = nla_data(tb[SEG6_IPTUNNEL_SRH]);
432	tuninfo_len = nla_len(tb[SEG6_IPTUNNEL_SRH]);
433
434	/* tuninfo must contain at least the iptunnel encap structure,
435	 * the SRH and one segment
436	 */
437	min_size = sizeof(*tuninfo) + sizeof(struct ipv6_sr_hdr) +
438		   sizeof(struct in6_addr);
439	if (tuninfo_len < min_size)
440		return -EINVAL;
441
442	switch (tuninfo->mode) {
443	case SEG6_IPTUN_MODE_INLINE:
444		if (family != AF_INET6)
445			return -EINVAL;
446
447		break;
448	case SEG6_IPTUN_MODE_ENCAP:
449		break;
450	case SEG6_IPTUN_MODE_L2ENCAP:
451		break;
452	default:
453		return -EINVAL;
454	}
455
456	/* verify that SRH is consistent */
457	if (!seg6_validate_srh(tuninfo->srh, tuninfo_len - sizeof(*tuninfo), false))
458		return -EINVAL;
459
460	newts = lwtunnel_state_alloc(tuninfo_len + sizeof(*slwt));
461	if (!newts)
462		return -ENOMEM;
463
464	slwt = seg6_lwt_lwtunnel(newts);
465
466	err = dst_cache_init(&slwt->cache, GFP_ATOMIC);
467	if (err) {
468		kfree(newts);
469		return err;
470	}
471
472	memcpy(&slwt->tuninfo, tuninfo, tuninfo_len);
473
474	newts->type = LWTUNNEL_ENCAP_SEG6;
475	newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT;
476
477	if (tuninfo->mode != SEG6_IPTUN_MODE_L2ENCAP)
478		newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT;
479
480	newts->headroom = seg6_lwt_headroom(tuninfo);
481
482	*ts = newts;
483
484	return 0;
485}
486
487static void seg6_destroy_state(struct lwtunnel_state *lwt)
488{
489	dst_cache_destroy(&seg6_lwt_lwtunnel(lwt)->cache);
490}
491
492static int seg6_fill_encap_info(struct sk_buff *skb,
493				struct lwtunnel_state *lwtstate)
494{
495	struct seg6_iptunnel_encap *tuninfo = seg6_encap_lwtunnel(lwtstate);
496
497	if (nla_put_srh(skb, SEG6_IPTUNNEL_SRH, tuninfo))
498		return -EMSGSIZE;
499
500	return 0;
501}
502
/* lwtunnel get_encap_size hook: netlink space needed by
 * seg6_fill_encap_info(), including attribute header and padding.
 */
static int seg6_encap_nlsize(struct lwtunnel_state *lwtstate)
{
	struct seg6_iptunnel_encap *tuninfo = seg6_encap_lwtunnel(lwtstate);
	int payload = SEG6_IPTUN_ENCAP_SIZE(tuninfo);

	return nla_total_size(payload);
}
509
/* lwtunnel cmp_encap hook: 0 when both states carry byte-identical encap
 * info (mode + SRH), non-zero otherwise.
 */
static int seg6_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b)
{
	struct seg6_iptunnel_encap *lhs = seg6_encap_lwtunnel(a);
	struct seg6_iptunnel_encap *rhs = seg6_encap_lwtunnel(b);
	int lhs_len = SEG6_IPTUN_ENCAP_SIZE(lhs);
	int rhs_len = SEG6_IPTUN_ENCAP_SIZE(rhs);

	return lhs_len != rhs_len ? 1 : memcmp(lhs, rhs, lhs_len);
}
521
522static const struct lwtunnel_encap_ops seg6_iptun_ops = {
523	.build_state = seg6_build_state,
524	.destroy_state = seg6_destroy_state,
525	.output = seg6_output,
526	.input = seg6_input,
527	.fill_encap = seg6_fill_encap_info,
528	.get_encap_size = seg6_encap_nlsize,
529	.cmp_encap = seg6_encap_cmp,
530	.owner = THIS_MODULE,
531};
532
533int __init seg6_iptunnel_init(void)
534{
535	return lwtunnel_encap_add_ops(&seg6_iptun_ops, LWTUNNEL_ENCAP_SEG6);
536}
537
538void seg6_iptunnel_exit(void)
539{
540	lwtunnel_encap_del_ops(&seg6_iptun_ops, LWTUNNEL_ENCAP_SEG6);
541}
542