xref: /kernel/linux/linux-5.10/net/mpls/af_mpls.c (revision 8c2ecf20)
1// SPDX-License-Identifier: GPL-2.0-only
2#include <linux/types.h>
3#include <linux/skbuff.h>
4#include <linux/socket.h>
5#include <linux/sysctl.h>
6#include <linux/net.h>
7#include <linux/module.h>
8#include <linux/if_arp.h>
9#include <linux/ipv6.h>
10#include <linux/mpls.h>
11#include <linux/netconf.h>
12#include <linux/nospec.h>
13#include <linux/vmalloc.h>
14#include <linux/percpu.h>
15#include <net/ip.h>
16#include <net/dst.h>
17#include <net/sock.h>
18#include <net/arp.h>
19#include <net/ip_fib.h>
20#include <net/netevent.h>
21#include <net/ip_tunnels.h>
22#include <net/netns/generic.h>
23#if IS_ENABLED(CONFIG_IPV6)
24#include <net/ipv6.h>
25#endif
26#include <net/ipv6_stubs.h>
27#include <net/rtnh.h>
28#include "internal.h"
29
/* Upper bound, in bytes, on a single mpls_route allocation (route header
 * plus all of its nexthops); enforced by mpls_rt_alloc().
 */
#define MAX_MPLS_ROUTE_MEM	4096

/* Number of label stack entries mpls_multipath_hash() is willing to
 * inspect when choosing a path of a multipath route.
 */
#define MAX_MP_SELECT_LABELS 4

/* Pseudo neighbour-table id meaning "no RTA_VIA was given": such a
 * nexthop is transmitted using the output device's own address.
 */
#define MPLS_NEIGH_TABLE_UNSPEC (NEIGH_LINK_TABLE + 1)

/* Largest value of a 20-bit MPLS label field; its users are outside this
 * excerpt.
 */
static int label_limit = (1 << 20) - 1;
/* Largest 8-bit TTL value; its users are outside this excerpt. */
static int ttl_max = 255;
42
43#if IS_ENABLED(CONFIG_NET_IP_TUNNEL)
44static size_t ipgre_mpls_encap_hlen(struct ip_tunnel_encap *e)
45{
46	return sizeof(struct mpls_shim_hdr);
47}
48
49static const struct ip_tunnel_encap_ops mpls_iptun_ops = {
50	.encap_hlen	= ipgre_mpls_encap_hlen,
51};
52
53static int ipgre_tunnel_encap_add_mpls_ops(void)
54{
55	return ip_tunnel_encap_add_ops(&mpls_iptun_ops, TUNNEL_ENCAP_MPLS);
56}
57
58static void ipgre_tunnel_encap_del_mpls_ops(void)
59{
60	ip_tunnel_encap_del_ops(&mpls_iptun_ops, TUNNEL_ENCAP_MPLS);
61}
62#else
63static int ipgre_tunnel_encap_add_mpls_ops(void)
64{
65	return 0;
66}
67
68static void ipgre_tunnel_encap_del_mpls_ops(void)
69{
70}
71#endif
72
73static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
74		       struct nlmsghdr *nlh, struct net *net, u32 portid,
75		       unsigned int nlm_flags);
76
77static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
78{
79	struct mpls_route *rt = NULL;
80
81	if (index < net->mpls.platform_labels) {
82		struct mpls_route __rcu **platform_label =
83			rcu_dereference(net->mpls.platform_label);
84		rt = rcu_dereference(platform_label[index]);
85	}
86	return rt;
87}
88
89bool mpls_output_possible(const struct net_device *dev)
90{
91	return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
92}
93EXPORT_SYMBOL_GPL(mpls_output_possible);
94
95static u8 *__mpls_nh_via(struct mpls_route *rt, struct mpls_nh *nh)
96{
97	return (u8 *)nh + rt->rt_via_offset;
98}
99
100static const u8 *mpls_nh_via(const struct mpls_route *rt,
101			     const struct mpls_nh *nh)
102{
103	return __mpls_nh_via((struct mpls_route *)rt, (struct mpls_nh *)nh);
104}
105
106static unsigned int mpls_nh_header_size(const struct mpls_nh *nh)
107{
108	/* The size of the layer 2.5 labels to be added for this route */
109	return nh->nh_labels * sizeof(struct mpls_shim_hdr);
110}
111
112unsigned int mpls_dev_mtu(const struct net_device *dev)
113{
114	/* The amount of data the layer 2 frame can hold */
115	return dev->mtu;
116}
117EXPORT_SYMBOL_GPL(mpls_dev_mtu);
118
119bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
120{
121	if (skb->len <= mtu)
122		return false;
123
124	if (skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu))
125		return false;
126
127	return true;
128}
129EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
130
131void mpls_stats_inc_outucastpkts(struct net_device *dev,
132				 const struct sk_buff *skb)
133{
134	struct mpls_dev *mdev;
135
136	if (skb->protocol == htons(ETH_P_MPLS_UC)) {
137		mdev = mpls_dev_get(dev);
138		if (mdev)
139			MPLS_INC_STATS_LEN(mdev, skb->len,
140					   tx_packets,
141					   tx_bytes);
142	} else if (skb->protocol == htons(ETH_P_IP)) {
143		IP_UPD_PO_STATS(dev_net(dev), IPSTATS_MIB_OUT, skb->len);
144#if IS_ENABLED(CONFIG_IPV6)
145	} else if (skb->protocol == htons(ETH_P_IPV6)) {
146		struct inet6_dev *in6dev = __in6_dev_get(dev);
147
148		if (in6dev)
149			IP6_UPD_PO_STATS(dev_net(dev), in6dev,
150					 IPSTATS_MIB_OUT, skb->len);
151#endif
152	}
153}
154EXPORT_SYMBOL_GPL(mpls_stats_inc_outucastpkts);
155
156static u32 mpls_multipath_hash(struct mpls_route *rt, struct sk_buff *skb)
157{
158	struct mpls_entry_decoded dec;
159	unsigned int mpls_hdr_len = 0;
160	struct mpls_shim_hdr *hdr;
161	bool eli_seen = false;
162	int label_index;
163	u32 hash = 0;
164
165	for (label_index = 0; label_index < MAX_MP_SELECT_LABELS;
166	     label_index++) {
167		mpls_hdr_len += sizeof(*hdr);
168		if (!pskb_may_pull(skb, mpls_hdr_len))
169			break;
170
171		/* Read and decode the current label */
172		hdr = mpls_hdr(skb) + label_index;
173		dec = mpls_entry_decode(hdr);
174
175		/* RFC6790 - reserved labels MUST NOT be used as keys
176		 * for the load-balancing function
177		 */
178		if (likely(dec.label >= MPLS_LABEL_FIRST_UNRESERVED)) {
179			hash = jhash_1word(dec.label, hash);
180
181			/* The entropy label follows the entropy label
182			 * indicator, so this means that the entropy
183			 * label was just added to the hash - no need to
184			 * go any deeper either in the label stack or in the
185			 * payload
186			 */
187			if (eli_seen)
188				break;
189		} else if (dec.label == MPLS_LABEL_ENTROPY) {
190			eli_seen = true;
191		}
192
193		if (!dec.bos)
194			continue;
195
196		/* found bottom label; does skb have room for a header? */
197		if (pskb_may_pull(skb, mpls_hdr_len + sizeof(struct iphdr))) {
198			const struct iphdr *v4hdr;
199
200			v4hdr = (const struct iphdr *)(hdr + 1);
201			if (v4hdr->version == 4) {
202				hash = jhash_3words(ntohl(v4hdr->saddr),
203						    ntohl(v4hdr->daddr),
204						    v4hdr->protocol, hash);
205			} else if (v4hdr->version == 6 &&
206				   pskb_may_pull(skb, mpls_hdr_len +
207						 sizeof(struct ipv6hdr))) {
208				const struct ipv6hdr *v6hdr;
209
210				v6hdr = (const struct ipv6hdr *)(hdr + 1);
211				hash = __ipv6_addr_jhash(&v6hdr->saddr, hash);
212				hash = __ipv6_addr_jhash(&v6hdr->daddr, hash);
213				hash = jhash_1word(v6hdr->nexthdr, hash);
214			}
215		}
216
217		break;
218	}
219
220	return hash;
221}
222
223static struct mpls_nh *mpls_get_nexthop(struct mpls_route *rt, u8 index)
224{
225	return (struct mpls_nh *)((u8 *)rt->rt_nh + index * rt->rt_nh_size);
226}
227
228/* number of alive nexthops (rt->rt_nhn_alive) and the flags for
229 * a next hop (nh->nh_flags) are modified by netdev event handlers.
230 * Since those fields can change at any moment, use READ_ONCE to
231 * access both.
232 */
233static struct mpls_nh *mpls_select_multipath(struct mpls_route *rt,
234					     struct sk_buff *skb)
235{
236	u32 hash = 0;
237	int nh_index = 0;
238	int n = 0;
239	u8 alive;
240
241	/* No need to look further into packet if there's only
242	 * one path
243	 */
244	if (rt->rt_nhn == 1)
245		return rt->rt_nh;
246
247	alive = READ_ONCE(rt->rt_nhn_alive);
248	if (alive == 0)
249		return NULL;
250
251	hash = mpls_multipath_hash(rt, skb);
252	nh_index = hash % alive;
253	if (alive == rt->rt_nhn)
254		goto out;
255	for_nexthops(rt) {
256		unsigned int nh_flags = READ_ONCE(nh->nh_flags);
257
258		if (nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
259			continue;
260		if (n == nh_index)
261			return nh;
262		n++;
263	} endfor_nexthops(rt);
264
265out:
266	return mpls_get_nexthop(rt, nh_index);
267}
268
269static bool mpls_egress(struct net *net, struct mpls_route *rt,
270			struct sk_buff *skb, struct mpls_entry_decoded dec)
271{
272	enum mpls_payload_type payload_type;
273	bool success = false;
274
275	/* The IPv4 code below accesses through the IPv4 header
276	 * checksum, which is 12 bytes into the packet.
277	 * The IPv6 code below accesses through the IPv6 hop limit
278	 * which is 8 bytes into the packet.
279	 *
280	 * For all supported cases there should always be at least 12
281	 * bytes of packet data present.  The IPv4 header is 20 bytes
282	 * without options and the IPv6 header is always 40 bytes
283	 * long.
284	 */
285	if (!pskb_may_pull(skb, 12))
286		return false;
287
288	payload_type = rt->rt_payload_type;
289	if (payload_type == MPT_UNSPEC)
290		payload_type = ip_hdr(skb)->version;
291
292	switch (payload_type) {
293	case MPT_IPV4: {
294		struct iphdr *hdr4 = ip_hdr(skb);
295		u8 new_ttl;
296		skb->protocol = htons(ETH_P_IP);
297
298		/* If propagating TTL, take the decremented TTL from
299		 * the incoming MPLS header, otherwise decrement the
300		 * TTL, but only if not 0 to avoid underflow.
301		 */
302		if (rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED ||
303		    (rt->rt_ttl_propagate == MPLS_TTL_PROP_DEFAULT &&
304		     net->mpls.ip_ttl_propagate))
305			new_ttl = dec.ttl;
306		else
307			new_ttl = hdr4->ttl ? hdr4->ttl - 1 : 0;
308
309		csum_replace2(&hdr4->check,
310			      htons(hdr4->ttl << 8),
311			      htons(new_ttl << 8));
312		hdr4->ttl = new_ttl;
313		success = true;
314		break;
315	}
316	case MPT_IPV6: {
317		struct ipv6hdr *hdr6 = ipv6_hdr(skb);
318		skb->protocol = htons(ETH_P_IPV6);
319
320		/* If propagating TTL, take the decremented TTL from
321		 * the incoming MPLS header, otherwise decrement the
322		 * hop limit, but only if not 0 to avoid underflow.
323		 */
324		if (rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED ||
325		    (rt->rt_ttl_propagate == MPLS_TTL_PROP_DEFAULT &&
326		     net->mpls.ip_ttl_propagate))
327			hdr6->hop_limit = dec.ttl;
328		else if (hdr6->hop_limit)
329			hdr6->hop_limit = hdr6->hop_limit - 1;
330		success = true;
331		break;
332	}
333	case MPT_UNSPEC:
334		/* Should have decided which protocol it is by now */
335		break;
336	}
337
338	return success;
339}
340
341static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
342			struct packet_type *pt, struct net_device *orig_dev)
343{
344	struct net *net = dev_net(dev);
345	struct mpls_shim_hdr *hdr;
346	struct mpls_route *rt;
347	struct mpls_nh *nh;
348	struct mpls_entry_decoded dec;
349	struct net_device *out_dev;
350	struct mpls_dev *out_mdev;
351	struct mpls_dev *mdev;
352	unsigned int hh_len;
353	unsigned int new_header_size;
354	unsigned int mtu;
355	int err;
356
357	/* Careful this entire function runs inside of an rcu critical section */
358
359	mdev = mpls_dev_get(dev);
360	if (!mdev)
361		goto drop;
362
363	MPLS_INC_STATS_LEN(mdev, skb->len, rx_packets,
364			   rx_bytes);
365
366	if (!mdev->input_enabled) {
367		MPLS_INC_STATS(mdev, rx_dropped);
368		goto drop;
369	}
370
371	if (skb->pkt_type != PACKET_HOST)
372		goto err;
373
374	if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
375		goto err;
376
377	if (!pskb_may_pull(skb, sizeof(*hdr)))
378		goto err;
379
380	/* Read and decode the label */
381	hdr = mpls_hdr(skb);
382	dec = mpls_entry_decode(hdr);
383
384	rt = mpls_route_input_rcu(net, dec.label);
385	if (!rt) {
386		MPLS_INC_STATS(mdev, rx_noroute);
387		goto drop;
388	}
389
390	nh = mpls_select_multipath(rt, skb);
391	if (!nh)
392		goto err;
393
394	/* Pop the label */
395	skb_pull(skb, sizeof(*hdr));
396	skb_reset_network_header(skb);
397
398	skb_orphan(skb);
399
400	if (skb_warn_if_lro(skb))
401		goto err;
402
403	skb_forward_csum(skb);
404
405	/* Verify ttl is valid */
406	if (dec.ttl <= 1)
407		goto err;
408	dec.ttl -= 1;
409
410	/* Find the output device */
411	out_dev = rcu_dereference(nh->nh_dev);
412	if (!mpls_output_possible(out_dev))
413		goto tx_err;
414
415	/* Verify the destination can hold the packet */
416	new_header_size = mpls_nh_header_size(nh);
417	mtu = mpls_dev_mtu(out_dev);
418	if (mpls_pkt_too_big(skb, mtu - new_header_size))
419		goto tx_err;
420
421	hh_len = LL_RESERVED_SPACE(out_dev);
422	if (!out_dev->header_ops)
423		hh_len = 0;
424
425	/* Ensure there is enough space for the headers in the skb */
426	if (skb_cow(skb, hh_len + new_header_size))
427		goto tx_err;
428
429	skb->dev = out_dev;
430	skb->protocol = htons(ETH_P_MPLS_UC);
431
432	if (unlikely(!new_header_size && dec.bos)) {
433		/* Penultimate hop popping */
434		if (!mpls_egress(dev_net(out_dev), rt, skb, dec))
435			goto err;
436	} else {
437		bool bos;
438		int i;
439		skb_push(skb, new_header_size);
440		skb_reset_network_header(skb);
441		/* Push the new labels */
442		hdr = mpls_hdr(skb);
443		bos = dec.bos;
444		for (i = nh->nh_labels - 1; i >= 0; i--) {
445			hdr[i] = mpls_entry_encode(nh->nh_label[i],
446						   dec.ttl, 0, bos);
447			bos = false;
448		}
449	}
450
451	mpls_stats_inc_outucastpkts(out_dev, skb);
452
453	/* If via wasn't specified then send out using device address */
454	if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC)
455		err = neigh_xmit(NEIGH_LINK_TABLE, out_dev,
456				 out_dev->dev_addr, skb);
457	else
458		err = neigh_xmit(nh->nh_via_table, out_dev,
459				 mpls_nh_via(rt, nh), skb);
460	if (err)
461		net_dbg_ratelimited("%s: packet transmission failed: %d\n",
462				    __func__, err);
463	return 0;
464
465tx_err:
466	out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL;
467	if (out_mdev)
468		MPLS_INC_STATS(out_mdev, tx_errors);
469	goto drop;
470err:
471	MPLS_INC_STATS(mdev, rx_errors);
472drop:
473	kfree_skb(skb);
474	return NET_RX_DROP;
475}
476
477static struct packet_type mpls_packet_type __read_mostly = {
478	.type = cpu_to_be16(ETH_P_MPLS_UC),
479	.func = mpls_forward,
480};
481
482static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
483	[RTA_DST]		= { .type = NLA_U32 },
484	[RTA_OIF]		= { .type = NLA_U32 },
485	[RTA_TTL_PROPAGATE]	= { .type = NLA_U8 },
486};
487
488struct mpls_route_config {
489	u32			rc_protocol;
490	u32			rc_ifindex;
491	u8			rc_via_table;
492	u8			rc_via_alen;
493	u8			rc_via[MAX_VIA_ALEN];
494	u32			rc_label;
495	u8			rc_ttl_propagate;
496	u8			rc_output_labels;
497	u32			rc_output_label[MAX_NEW_LABELS];
498	u32			rc_nlflags;
499	enum mpls_payload_type	rc_payload_type;
500	struct nl_info		rc_nlinfo;
501	struct rtnexthop	*rc_mp;
502	int			rc_mp_len;
503};
504
505/* all nexthops within a route have the same size based on max
506 * number of labels and max via length for a hop
507 */
508static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels)
509{
510	u8 nh_size = MPLS_NH_SIZE(max_labels, max_alen);
511	struct mpls_route *rt;
512	size_t size;
513
514	size = sizeof(*rt) + num_nh * nh_size;
515	if (size > MAX_MPLS_ROUTE_MEM)
516		return ERR_PTR(-EINVAL);
517
518	rt = kzalloc(size, GFP_KERNEL);
519	if (!rt)
520		return ERR_PTR(-ENOMEM);
521
522	rt->rt_nhn = num_nh;
523	rt->rt_nhn_alive = num_nh;
524	rt->rt_nh_size = nh_size;
525	rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels);
526
527	return rt;
528}
529
530static void mpls_rt_free(struct mpls_route *rt)
531{
532	if (rt)
533		kfree_rcu(rt, rt_rcu);
534}
535
536static void mpls_notify_route(struct net *net, unsigned index,
537			      struct mpls_route *old, struct mpls_route *new,
538			      const struct nl_info *info)
539{
540	struct nlmsghdr *nlh = info ? info->nlh : NULL;
541	unsigned portid = info ? info->portid : 0;
542	int event = new ? RTM_NEWROUTE : RTM_DELROUTE;
543	struct mpls_route *rt = new ? new : old;
544	unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
545	/* Ignore reserved labels for now */
546	if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED))
547		rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
548}
549
550static void mpls_route_update(struct net *net, unsigned index,
551			      struct mpls_route *new,
552			      const struct nl_info *info)
553{
554	struct mpls_route __rcu **platform_label;
555	struct mpls_route *rt;
556
557	ASSERT_RTNL();
558
559	platform_label = rtnl_dereference(net->mpls.platform_label);
560	rt = rtnl_dereference(platform_label[index]);
561	rcu_assign_pointer(platform_label[index], new);
562
563	mpls_notify_route(net, index, rt, new, info);
564
565	/* If we removed a route free it now */
566	mpls_rt_free(rt);
567}
568
569static unsigned find_free_label(struct net *net)
570{
571	struct mpls_route __rcu **platform_label;
572	size_t platform_labels;
573	unsigned index;
574
575	platform_label = rtnl_dereference(net->mpls.platform_label);
576	platform_labels = net->mpls.platform_labels;
577	for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
578	     index++) {
579		if (!rtnl_dereference(platform_label[index]))
580			return index;
581	}
582	return LABEL_NOT_SPECIFIED;
583}
584
585#if IS_ENABLED(CONFIG_INET)
586static struct net_device *inet_fib_lookup_dev(struct net *net,
587					      const void *addr)
588{
589	struct net_device *dev;
590	struct rtable *rt;
591	struct in_addr daddr;
592
593	memcpy(&daddr, addr, sizeof(struct in_addr));
594	rt = ip_route_output(net, daddr.s_addr, 0, 0, 0);
595	if (IS_ERR(rt))
596		return ERR_CAST(rt);
597
598	dev = rt->dst.dev;
599	dev_hold(dev);
600
601	ip_rt_put(rt);
602
603	return dev;
604}
605#else
606static struct net_device *inet_fib_lookup_dev(struct net *net,
607					      const void *addr)
608{
609	return ERR_PTR(-EAFNOSUPPORT);
610}
611#endif
612
613#if IS_ENABLED(CONFIG_IPV6)
614static struct net_device *inet6_fib_lookup_dev(struct net *net,
615					       const void *addr)
616{
617	struct net_device *dev;
618	struct dst_entry *dst;
619	struct flowi6 fl6;
620
621	if (!ipv6_stub)
622		return ERR_PTR(-EAFNOSUPPORT);
623
624	memset(&fl6, 0, sizeof(fl6));
625	memcpy(&fl6.daddr, addr, sizeof(struct in6_addr));
626	dst = ipv6_stub->ipv6_dst_lookup_flow(net, NULL, &fl6, NULL);
627	if (IS_ERR(dst))
628		return ERR_CAST(dst);
629
630	dev = dst->dev;
631	dev_hold(dev);
632	dst_release(dst);
633
634	return dev;
635}
636#else
637static struct net_device *inet6_fib_lookup_dev(struct net *net,
638					       const void *addr)
639{
640	return ERR_PTR(-EAFNOSUPPORT);
641}
642#endif
643
644static struct net_device *find_outdev(struct net *net,
645				      struct mpls_route *rt,
646				      struct mpls_nh *nh, int oif)
647{
648	struct net_device *dev = NULL;
649
650	if (!oif) {
651		switch (nh->nh_via_table) {
652		case NEIGH_ARP_TABLE:
653			dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh));
654			break;
655		case NEIGH_ND_TABLE:
656			dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh));
657			break;
658		case NEIGH_LINK_TABLE:
659			break;
660		}
661	} else {
662		dev = dev_get_by_index(net, oif);
663	}
664
665	if (!dev)
666		return ERR_PTR(-ENODEV);
667
668	if (IS_ERR(dev))
669		return dev;
670
671	/* The caller is holding rtnl anyways, so release the dev reference */
672	dev_put(dev);
673
674	return dev;
675}
676
677static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt,
678			      struct mpls_nh *nh, int oif)
679{
680	struct net_device *dev = NULL;
681	int err = -ENODEV;
682
683	dev = find_outdev(net, rt, nh, oif);
684	if (IS_ERR(dev)) {
685		err = PTR_ERR(dev);
686		dev = NULL;
687		goto errout;
688	}
689
690	/* Ensure this is a supported device */
691	err = -EINVAL;
692	if (!mpls_dev_get(dev))
693		goto errout;
694
695	if ((nh->nh_via_table == NEIGH_LINK_TABLE) &&
696	    (dev->addr_len != nh->nh_via_alen))
697		goto errout;
698
699	RCU_INIT_POINTER(nh->nh_dev, dev);
700
701	if (!(dev->flags & IFF_UP)) {
702		nh->nh_flags |= RTNH_F_DEAD;
703	} else {
704		unsigned int flags;
705
706		flags = dev_get_flags(dev);
707		if (!(flags & (IFF_RUNNING | IFF_LOWER_UP)))
708			nh->nh_flags |= RTNH_F_LINKDOWN;
709	}
710
711	return 0;
712
713errout:
714	return err;
715}
716
717static int nla_get_via(const struct nlattr *nla, u8 *via_alen, u8 *via_table,
718		       u8 via_addr[], struct netlink_ext_ack *extack)
719{
720	struct rtvia *via = nla_data(nla);
721	int err = -EINVAL;
722	int alen;
723
724	if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr)) {
725		NL_SET_ERR_MSG_ATTR(extack, nla,
726				    "Invalid attribute length for RTA_VIA");
727		goto errout;
728	}
729	alen = nla_len(nla) -
730			offsetof(struct rtvia, rtvia_addr);
731	if (alen > MAX_VIA_ALEN) {
732		NL_SET_ERR_MSG_ATTR(extack, nla,
733				    "Invalid address length for RTA_VIA");
734		goto errout;
735	}
736
737	/* Validate the address family */
738	switch (via->rtvia_family) {
739	case AF_PACKET:
740		*via_table = NEIGH_LINK_TABLE;
741		break;
742	case AF_INET:
743		*via_table = NEIGH_ARP_TABLE;
744		if (alen != 4)
745			goto errout;
746		break;
747	case AF_INET6:
748		*via_table = NEIGH_ND_TABLE;
749		if (alen != 16)
750			goto errout;
751		break;
752	default:
753		/* Unsupported address family */
754		goto errout;
755	}
756
757	memcpy(via_addr, via->rtvia_addr, alen);
758	*via_alen = alen;
759	err = 0;
760
761errout:
762	return err;
763}
764
765static int mpls_nh_build_from_cfg(struct mpls_route_config *cfg,
766				  struct mpls_route *rt)
767{
768	struct net *net = cfg->rc_nlinfo.nl_net;
769	struct mpls_nh *nh = rt->rt_nh;
770	int err;
771	int i;
772
773	if (!nh)
774		return -ENOMEM;
775
776	nh->nh_labels = cfg->rc_output_labels;
777	for (i = 0; i < nh->nh_labels; i++)
778		nh->nh_label[i] = cfg->rc_output_label[i];
779
780	nh->nh_via_table = cfg->rc_via_table;
781	memcpy(__mpls_nh_via(rt, nh), cfg->rc_via, cfg->rc_via_alen);
782	nh->nh_via_alen = cfg->rc_via_alen;
783
784	err = mpls_nh_assign_dev(net, rt, nh, cfg->rc_ifindex);
785	if (err)
786		goto errout;
787
788	if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
789		rt->rt_nhn_alive--;
790
791	return 0;
792
793errout:
794	return err;
795}
796
797static int mpls_nh_build(struct net *net, struct mpls_route *rt,
798			 struct mpls_nh *nh, int oif, struct nlattr *via,
799			 struct nlattr *newdst, u8 max_labels,
800			 struct netlink_ext_ack *extack)
801{
802	int err = -ENOMEM;
803
804	if (!nh)
805		goto errout;
806
807	if (newdst) {
808		err = nla_get_labels(newdst, max_labels, &nh->nh_labels,
809				     nh->nh_label, extack);
810		if (err)
811			goto errout;
812	}
813
814	if (via) {
815		err = nla_get_via(via, &nh->nh_via_alen, &nh->nh_via_table,
816				  __mpls_nh_via(rt, nh), extack);
817		if (err)
818			goto errout;
819	} else {
820		nh->nh_via_table = MPLS_NEIGH_TABLE_UNSPEC;
821	}
822
823	err = mpls_nh_assign_dev(net, rt, nh, oif);
824	if (err)
825		goto errout;
826
827	return 0;
828
829errout:
830	return err;
831}
832
833static u8 mpls_count_nexthops(struct rtnexthop *rtnh, int len,
834			      u8 cfg_via_alen, u8 *max_via_alen,
835			      u8 *max_labels)
836{
837	int remaining = len;
838	u8 nhs = 0;
839
840	*max_via_alen = 0;
841	*max_labels = 0;
842
843	while (rtnh_ok(rtnh, remaining)) {
844		struct nlattr *nla, *attrs = rtnh_attrs(rtnh);
845		int attrlen;
846		u8 n_labels = 0;
847
848		attrlen = rtnh_attrlen(rtnh);
849		nla = nla_find(attrs, attrlen, RTA_VIA);
850		if (nla && nla_len(nla) >=
851		    offsetof(struct rtvia, rtvia_addr)) {
852			int via_alen = nla_len(nla) -
853				offsetof(struct rtvia, rtvia_addr);
854
855			if (via_alen <= MAX_VIA_ALEN)
856				*max_via_alen = max_t(u16, *max_via_alen,
857						      via_alen);
858		}
859
860		nla = nla_find(attrs, attrlen, RTA_NEWDST);
861		if (nla &&
862		    nla_get_labels(nla, MAX_NEW_LABELS, &n_labels,
863				   NULL, NULL) != 0)
864			return 0;
865
866		*max_labels = max_t(u8, *max_labels, n_labels);
867
868		/* number of nexthops is tracked by a u8.
869		 * Check for overflow.
870		 */
871		if (nhs == 255)
872			return 0;
873		nhs++;
874
875		rtnh = rtnh_next(rtnh, &remaining);
876	}
877
878	/* leftover implies invalid nexthop configuration, discard it */
879	return remaining > 0 ? 0 : nhs;
880}
881
882static int mpls_nh_build_multi(struct mpls_route_config *cfg,
883			       struct mpls_route *rt, u8 max_labels,
884			       struct netlink_ext_ack *extack)
885{
886	struct rtnexthop *rtnh = cfg->rc_mp;
887	struct nlattr *nla_via, *nla_newdst;
888	int remaining = cfg->rc_mp_len;
889	int err = 0;
890	u8 nhs = 0;
891
892	change_nexthops(rt) {
893		int attrlen;
894
895		nla_via = NULL;
896		nla_newdst = NULL;
897
898		err = -EINVAL;
899		if (!rtnh_ok(rtnh, remaining))
900			goto errout;
901
902		/* neither weighted multipath nor any flags
903		 * are supported
904		 */
905		if (rtnh->rtnh_hops || rtnh->rtnh_flags)
906			goto errout;
907
908		attrlen = rtnh_attrlen(rtnh);
909		if (attrlen > 0) {
910			struct nlattr *attrs = rtnh_attrs(rtnh);
911
912			nla_via = nla_find(attrs, attrlen, RTA_VIA);
913			nla_newdst = nla_find(attrs, attrlen, RTA_NEWDST);
914		}
915
916		err = mpls_nh_build(cfg->rc_nlinfo.nl_net, rt, nh,
917				    rtnh->rtnh_ifindex, nla_via, nla_newdst,
918				    max_labels, extack);
919		if (err)
920			goto errout;
921
922		if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
923			rt->rt_nhn_alive--;
924
925		rtnh = rtnh_next(rtnh, &remaining);
926		nhs++;
927	} endfor_nexthops(rt);
928
929	rt->rt_nhn = nhs;
930
931	return 0;
932
933errout:
934	return err;
935}
936
937static bool mpls_label_ok(struct net *net, unsigned int *index,
938			  struct netlink_ext_ack *extack)
939{
940	bool is_ok = true;
941
942	/* Reserved labels may not be set */
943	if (*index < MPLS_LABEL_FIRST_UNRESERVED) {
944		NL_SET_ERR_MSG(extack,
945			       "Invalid label - must be MPLS_LABEL_FIRST_UNRESERVED or higher");
946		is_ok = false;
947	}
948
949	/* The full 20 bit range may not be supported. */
950	if (is_ok && *index >= net->mpls.platform_labels) {
951		NL_SET_ERR_MSG(extack,
952			       "Label >= configured maximum in platform_labels");
953		is_ok = false;
954	}
955
956	*index = array_index_nospec(*index, net->mpls.platform_labels);
957	return is_ok;
958}
959
960static int mpls_route_add(struct mpls_route_config *cfg,
961			  struct netlink_ext_ack *extack)
962{
963	struct mpls_route __rcu **platform_label;
964	struct net *net = cfg->rc_nlinfo.nl_net;
965	struct mpls_route *rt, *old;
966	int err = -EINVAL;
967	u8 max_via_alen;
968	unsigned index;
969	u8 max_labels;
970	u8 nhs;
971
972	index = cfg->rc_label;
973
974	/* If a label was not specified during insert pick one */
975	if ((index == LABEL_NOT_SPECIFIED) &&
976	    (cfg->rc_nlflags & NLM_F_CREATE)) {
977		index = find_free_label(net);
978	}
979
980	if (!mpls_label_ok(net, &index, extack))
981		goto errout;
982
983	/* Append makes no sense with mpls */
984	err = -EOPNOTSUPP;
985	if (cfg->rc_nlflags & NLM_F_APPEND) {
986		NL_SET_ERR_MSG(extack, "MPLS does not support route append");
987		goto errout;
988	}
989
990	err = -EEXIST;
991	platform_label = rtnl_dereference(net->mpls.platform_label);
992	old = rtnl_dereference(platform_label[index]);
993	if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
994		goto errout;
995
996	err = -EEXIST;
997	if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
998		goto errout;
999
1000	err = -ENOENT;
1001	if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
1002		goto errout;
1003
1004	err = -EINVAL;
1005	if (cfg->rc_mp) {
1006		nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len,
1007					  cfg->rc_via_alen, &max_via_alen,
1008					  &max_labels);
1009	} else {
1010		max_via_alen = cfg->rc_via_alen;
1011		max_labels = cfg->rc_output_labels;
1012		nhs = 1;
1013	}
1014
1015	if (nhs == 0) {
1016		NL_SET_ERR_MSG(extack, "Route does not contain a nexthop");
1017		goto errout;
1018	}
1019
1020	err = -ENOMEM;
1021	rt = mpls_rt_alloc(nhs, max_via_alen, max_labels);
1022	if (IS_ERR(rt)) {
1023		err = PTR_ERR(rt);
1024		goto errout;
1025	}
1026
1027	rt->rt_protocol = cfg->rc_protocol;
1028	rt->rt_payload_type = cfg->rc_payload_type;
1029	rt->rt_ttl_propagate = cfg->rc_ttl_propagate;
1030
1031	if (cfg->rc_mp)
1032		err = mpls_nh_build_multi(cfg, rt, max_labels, extack);
1033	else
1034		err = mpls_nh_build_from_cfg(cfg, rt);
1035	if (err)
1036		goto freert;
1037
1038	mpls_route_update(net, index, rt, &cfg->rc_nlinfo);
1039
1040	return 0;
1041
1042freert:
1043	mpls_rt_free(rt);
1044errout:
1045	return err;
1046}
1047
1048static int mpls_route_del(struct mpls_route_config *cfg,
1049			  struct netlink_ext_ack *extack)
1050{
1051	struct net *net = cfg->rc_nlinfo.nl_net;
1052	unsigned index;
1053	int err = -EINVAL;
1054
1055	index = cfg->rc_label;
1056
1057	if (!mpls_label_ok(net, &index, extack))
1058		goto errout;
1059
1060	mpls_route_update(net, index, NULL, &cfg->rc_nlinfo);
1061
1062	err = 0;
1063errout:
1064	return err;
1065}
1066
1067static void mpls_get_stats(struct mpls_dev *mdev,
1068			   struct mpls_link_stats *stats)
1069{
1070	struct mpls_pcpu_stats *p;
1071	int i;
1072
1073	memset(stats, 0, sizeof(*stats));
1074
1075	for_each_possible_cpu(i) {
1076		struct mpls_link_stats local;
1077		unsigned int start;
1078
1079		p = per_cpu_ptr(mdev->stats, i);
1080		do {
1081			start = u64_stats_fetch_begin_irq(&p->syncp);
1082			local = p->stats;
1083		} while (u64_stats_fetch_retry_irq(&p->syncp, start));
1084
1085		stats->rx_packets	+= local.rx_packets;
1086		stats->rx_bytes		+= local.rx_bytes;
1087		stats->tx_packets	+= local.tx_packets;
1088		stats->tx_bytes		+= local.tx_bytes;
1089		stats->rx_errors	+= local.rx_errors;
1090		stats->tx_errors	+= local.tx_errors;
1091		stats->rx_dropped	+= local.rx_dropped;
1092		stats->tx_dropped	+= local.tx_dropped;
1093		stats->rx_noroute	+= local.rx_noroute;
1094	}
1095}
1096
1097static int mpls_fill_stats_af(struct sk_buff *skb,
1098			      const struct net_device *dev)
1099{
1100	struct mpls_link_stats *stats;
1101	struct mpls_dev *mdev;
1102	struct nlattr *nla;
1103
1104	mdev = mpls_dev_get(dev);
1105	if (!mdev)
1106		return -ENODATA;
1107
1108	nla = nla_reserve_64bit(skb, MPLS_STATS_LINK,
1109				sizeof(struct mpls_link_stats),
1110				MPLS_STATS_UNSPEC);
1111	if (!nla)
1112		return -EMSGSIZE;
1113
1114	stats = nla_data(nla);
1115	mpls_get_stats(mdev, stats);
1116
1117	return 0;
1118}
1119
1120static size_t mpls_get_stats_af_size(const struct net_device *dev)
1121{
1122	struct mpls_dev *mdev;
1123
1124	mdev = mpls_dev_get(dev);
1125	if (!mdev)
1126		return 0;
1127
1128	return nla_total_size_64bit(sizeof(struct mpls_link_stats));
1129}
1130
1131static int mpls_netconf_fill_devconf(struct sk_buff *skb, struct mpls_dev *mdev,
1132				     u32 portid, u32 seq, int event,
1133				     unsigned int flags, int type)
1134{
1135	struct nlmsghdr  *nlh;
1136	struct netconfmsg *ncm;
1137	bool all = false;
1138
1139	nlh = nlmsg_put(skb, portid, seq, event, sizeof(struct netconfmsg),
1140			flags);
1141	if (!nlh)
1142		return -EMSGSIZE;
1143
1144	if (type == NETCONFA_ALL)
1145		all = true;
1146
1147	ncm = nlmsg_data(nlh);
1148	ncm->ncm_family = AF_MPLS;
1149
1150	if (nla_put_s32(skb, NETCONFA_IFINDEX, mdev->dev->ifindex) < 0)
1151		goto nla_put_failure;
1152
1153	if ((all || type == NETCONFA_INPUT) &&
1154	    nla_put_s32(skb, NETCONFA_INPUT,
1155			mdev->input_enabled) < 0)
1156		goto nla_put_failure;
1157
1158	nlmsg_end(skb, nlh);
1159	return 0;
1160
1161nla_put_failure:
1162	nlmsg_cancel(skb, nlh);
1163	return -EMSGSIZE;
1164}
1165
1166static int mpls_netconf_msgsize_devconf(int type)
1167{
1168	int size = NLMSG_ALIGN(sizeof(struct netconfmsg))
1169			+ nla_total_size(4); /* NETCONFA_IFINDEX */
1170	bool all = false;
1171
1172	if (type == NETCONFA_ALL)
1173		all = true;
1174
1175	if (all || type == NETCONFA_INPUT)
1176		size += nla_total_size(4);
1177
1178	return size;
1179}
1180
1181static void mpls_netconf_notify_devconf(struct net *net, int event,
1182					int type, struct mpls_dev *mdev)
1183{
1184	struct sk_buff *skb;
1185	int err = -ENOBUFS;
1186
1187	skb = nlmsg_new(mpls_netconf_msgsize_devconf(type), GFP_KERNEL);
1188	if (!skb)
1189		goto errout;
1190
1191	err = mpls_netconf_fill_devconf(skb, mdev, 0, 0, event, 0, type);
1192	if (err < 0) {
1193		/* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */
1194		WARN_ON(err == -EMSGSIZE);
1195		kfree_skb(skb);
1196		goto errout;
1197	}
1198
1199	rtnl_notify(skb, net, 0, RTNLGRP_MPLS_NETCONF, NULL, GFP_KERNEL);
1200	return;
1201errout:
1202	if (err < 0)
1203		rtnl_set_sk_err(net, RTNLGRP_MPLS_NETCONF, err);
1204}
1205
1206static const struct nla_policy devconf_mpls_policy[NETCONFA_MAX + 1] = {
1207	[NETCONFA_IFINDEX]	= { .len = sizeof(int) },
1208};
1209
1210static int mpls_netconf_valid_get_req(struct sk_buff *skb,
1211				      const struct nlmsghdr *nlh,
1212				      struct nlattr **tb,
1213				      struct netlink_ext_ack *extack)
1214{
1215	int i, err;
1216
1217	if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) {
1218		NL_SET_ERR_MSG_MOD(extack,
1219				   "Invalid header for netconf get request");
1220		return -EINVAL;
1221	}
1222
1223	if (!netlink_strict_get_check(skb))
1224		return nlmsg_parse_deprecated(nlh, sizeof(struct netconfmsg),
1225					      tb, NETCONFA_MAX,
1226					      devconf_mpls_policy, extack);
1227
1228	err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct netconfmsg),
1229					    tb, NETCONFA_MAX,
1230					    devconf_mpls_policy, extack);
1231	if (err)
1232		return err;
1233
1234	for (i = 0; i <= NETCONFA_MAX; i++) {
1235		if (!tb[i])
1236			continue;
1237
1238		switch (i) {
1239		case NETCONFA_IFINDEX:
1240			break;
1241		default:
1242			NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in netconf get request");
1243			return -EINVAL;
1244		}
1245	}
1246
1247	return 0;
1248}
1249
1250static int mpls_netconf_get_devconf(struct sk_buff *in_skb,
1251				    struct nlmsghdr *nlh,
1252				    struct netlink_ext_ack *extack)
1253{
1254	struct net *net = sock_net(in_skb->sk);
1255	struct nlattr *tb[NETCONFA_MAX + 1];
1256	struct net_device *dev;
1257	struct mpls_dev *mdev;
1258	struct sk_buff *skb;
1259	int ifindex;
1260	int err;
1261
1262	err = mpls_netconf_valid_get_req(in_skb, nlh, tb, extack);
1263	if (err < 0)
1264		goto errout;
1265
1266	err = -EINVAL;
1267	if (!tb[NETCONFA_IFINDEX])
1268		goto errout;
1269
1270	ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]);
1271	dev = __dev_get_by_index(net, ifindex);
1272	if (!dev)
1273		goto errout;
1274
1275	mdev = mpls_dev_get(dev);
1276	if (!mdev)
1277		goto errout;
1278
1279	err = -ENOBUFS;
1280	skb = nlmsg_new(mpls_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL);
1281	if (!skb)
1282		goto errout;
1283
1284	err = mpls_netconf_fill_devconf(skb, mdev,
1285					NETLINK_CB(in_skb).portid,
1286					nlh->nlmsg_seq, RTM_NEWNETCONF, 0,
1287					NETCONFA_ALL);
1288	if (err < 0) {
1289		/* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */
1290		WARN_ON(err == -EMSGSIZE);
1291		kfree_skb(skb);
1292		goto errout;
1293	}
1294	err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid);
1295errout:
1296	return err;
1297}
1298
1299static int mpls_netconf_dump_devconf(struct sk_buff *skb,
1300				     struct netlink_callback *cb)
1301{
1302	const struct nlmsghdr *nlh = cb->nlh;
1303	struct net *net = sock_net(skb->sk);
1304	struct hlist_head *head;
1305	struct net_device *dev;
1306	struct mpls_dev *mdev;
1307	int idx, s_idx;
1308	int h, s_h;
1309
1310	if (cb->strict_check) {
1311		struct netlink_ext_ack *extack = cb->extack;
1312		struct netconfmsg *ncm;
1313
1314		if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) {
1315			NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf dump request");
1316			return -EINVAL;
1317		}
1318
1319		if (nlmsg_attrlen(nlh, sizeof(*ncm))) {
1320			NL_SET_ERR_MSG_MOD(extack, "Invalid data after header in netconf dump request");
1321			return -EINVAL;
1322		}
1323	}
1324
1325	s_h = cb->args[0];
1326	s_idx = idx = cb->args[1];
1327
1328	for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) {
1329		idx = 0;
1330		head = &net->dev_index_head[h];
1331		rcu_read_lock();
1332		cb->seq = net->dev_base_seq;
1333		hlist_for_each_entry_rcu(dev, head, index_hlist) {
1334			if (idx < s_idx)
1335				goto cont;
1336			mdev = mpls_dev_get(dev);
1337			if (!mdev)
1338				goto cont;
1339			if (mpls_netconf_fill_devconf(skb, mdev,
1340						      NETLINK_CB(cb->skb).portid,
1341						      nlh->nlmsg_seq,
1342						      RTM_NEWNETCONF,
1343						      NLM_F_MULTI,
1344						      NETCONFA_ALL) < 0) {
1345				rcu_read_unlock();
1346				goto done;
1347			}
1348			nl_dump_check_consistent(cb, nlmsg_hdr(skb));
1349cont:
1350			idx++;
1351		}
1352		rcu_read_unlock();
1353	}
1354done:
1355	cb->args[0] = h;
1356	cb->args[1] = idx;
1357
1358	return skb->len;
1359}
1360
/*
 * Placeholder .data value for entries of mpls_dev_table: the byte offset of
 * @field inside struct mpls_dev, stored in a pointer.
 * mpls_dev_sysctl_register() adds it to the address of the real mpls_dev.
 *
 * Using offsetof() cast to a pointer yields a valid address constant for a
 * static initializer.  It avoids forming &((struct mpls_dev *)0)->field,
 * which is member access through a null pointer (undefined behaviour).
 */
#define MPLS_PERDEV_SYSCTL_OFFSET(field)	\
	((void *)offsetof(struct mpls_dev, field))
1363
1364static int mpls_conf_proc(struct ctl_table *ctl, int write,
1365			  void *buffer, size_t *lenp, loff_t *ppos)
1366{
1367	int oval = *(int *)ctl->data;
1368	int ret = proc_dointvec(ctl, write, buffer, lenp, ppos);
1369
1370	if (write) {
1371		struct mpls_dev *mdev = ctl->extra1;
1372		int i = (int *)ctl->data - (int *)mdev;
1373		struct net *net = ctl->extra2;
1374		int val = *(int *)ctl->data;
1375
1376		if (i == offsetof(struct mpls_dev, input_enabled) &&
1377		    val != oval) {
1378			mpls_netconf_notify_devconf(net, RTM_NEWNETCONF,
1379						    NETCONFA_INPUT, mdev);
1380		}
1381	}
1382
1383	return ret;
1384}
1385
1386static const struct ctl_table mpls_dev_table[] = {
1387	{
1388		.procname	= "input",
1389		.maxlen		= sizeof(int),
1390		.mode		= 0644,
1391		.proc_handler	= mpls_conf_proc,
1392		.data		= MPLS_PERDEV_SYSCTL_OFFSET(input_enabled),
1393	},
1394	{ }
1395};
1396
1397static int mpls_dev_sysctl_register(struct net_device *dev,
1398				    struct mpls_dev *mdev)
1399{
1400	char path[sizeof("net/mpls/conf/") + IFNAMSIZ];
1401	struct net *net = dev_net(dev);
1402	struct ctl_table *table;
1403	int i;
1404
1405	table = kmemdup(&mpls_dev_table, sizeof(mpls_dev_table), GFP_KERNEL);
1406	if (!table)
1407		goto out;
1408
1409	/* Table data contains only offsets relative to the base of
1410	 * the mdev at this point, so make them absolute.
1411	 */
1412	for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++) {
1413		table[i].data = (char *)mdev + (uintptr_t)table[i].data;
1414		table[i].extra1 = mdev;
1415		table[i].extra2 = net;
1416	}
1417
1418	snprintf(path, sizeof(path), "net/mpls/conf/%s", dev->name);
1419
1420	mdev->sysctl = register_net_sysctl(net, path, table);
1421	if (!mdev->sysctl)
1422		goto free;
1423
1424	mpls_netconf_notify_devconf(net, RTM_NEWNETCONF, NETCONFA_ALL, mdev);
1425	return 0;
1426
1427free:
1428	kfree(table);
1429out:
1430	mdev->sysctl = NULL;
1431	return -ENOBUFS;
1432}
1433
1434static void mpls_dev_sysctl_unregister(struct net_device *dev,
1435				       struct mpls_dev *mdev)
1436{
1437	struct net *net = dev_net(dev);
1438	struct ctl_table *table;
1439
1440	if (!mdev->sysctl)
1441		return;
1442
1443	table = mdev->sysctl->ctl_table_arg;
1444	unregister_net_sysctl_table(mdev->sysctl);
1445	kfree(table);
1446
1447	mpls_netconf_notify_devconf(net, RTM_DELNETCONF, 0, mdev);
1448}
1449
1450static struct mpls_dev *mpls_add_dev(struct net_device *dev)
1451{
1452	struct mpls_dev *mdev;
1453	int err = -ENOMEM;
1454	int i;
1455
1456	ASSERT_RTNL();
1457
1458	mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
1459	if (!mdev)
1460		return ERR_PTR(err);
1461
1462	mdev->stats = alloc_percpu(struct mpls_pcpu_stats);
1463	if (!mdev->stats)
1464		goto free;
1465
1466	for_each_possible_cpu(i) {
1467		struct mpls_pcpu_stats *mpls_stats;
1468
1469		mpls_stats = per_cpu_ptr(mdev->stats, i);
1470		u64_stats_init(&mpls_stats->syncp);
1471	}
1472
1473	mdev->dev = dev;
1474
1475	err = mpls_dev_sysctl_register(dev, mdev);
1476	if (err)
1477		goto free;
1478
1479	rcu_assign_pointer(dev->mpls_ptr, mdev);
1480
1481	return mdev;
1482
1483free:
1484	free_percpu(mdev->stats);
1485	kfree(mdev);
1486	return ERR_PTR(err);
1487}
1488
1489static void mpls_dev_destroy_rcu(struct rcu_head *head)
1490{
1491	struct mpls_dev *mdev = container_of(head, struct mpls_dev, rcu);
1492
1493	free_percpu(mdev->stats);
1494	kfree(mdev);
1495}
1496
/*
 * mpls_ifdown - update the MPLS routing table after @dev went down, lost
 * carrier, or is being unregistered.
 * @dev:   the device the event refers to
 * @event: NETDEV_DOWN, NETDEV_CHANGE or NETDEV_UNREGISTER
 *
 * Every nexthop that uses @dev is flagged:
 *   - NETDEV_DOWN and NETDEV_UNREGISTER set RTNH_F_DEAD and RTNH_F_LINKDOWN;
 *   - NETDEV_CHANGE sets RTNH_F_LINKDOWN only.
 * Each route's rt_nhn_alive is then recomputed as the number of nexthops
 * with neither flag set.
 *
 * NETDEV_UNREGISTER also removes the device from routes:
 *   - a route whose nexthops all either use @dev or already lost their
 *     device is deleted;
 *   - otherwise, if any nexthop uses @dev, a kmalloc'ed copy of the route
 *     is edited (nh_dev cleared) and installed with mpls_route_update(),
 *     rather than editing the live route in place.  Presumably this is so
 *     RCU readers of the old route never see nh_dev become NULL.
 *     NOTE(review): confirm.
 *
 * Return: 0, or -ENOMEM if a route copy could not be allocated.  In that
 * case routes at lower label indices have already been updated.
 */
1497static int mpls_ifdown(struct net_device *dev, int event)
1498{
1499	struct mpls_route __rcu **platform_label;
1500	struct net *net = dev_net(dev);
1501	unsigned index;
1502
1503	platform_label = rtnl_dereference(net->mpls.platform_label);
1504	for (index = 0; index < net->mpls.platform_labels; index++) {
1505		struct mpls_route *rt = rtnl_dereference(platform_label[index]);
		/* nh_del: some nexthop of rt uses @dev (UNREGISTER only) */
1506		bool nh_del = false;
1507		u8 alive = 0;
1508
1509		if (!rt)
1510			continue;
1511
1512		if (event == NETDEV_UNREGISTER) {
			/* Nexthops that will have no device once @dev is gone. */
1513			u8 deleted = 0;
1514
1515			for_nexthops(rt) {
1516				struct net_device *nh_dev =
1517					rtnl_dereference(nh->nh_dev);
1518
1519				if (!nh_dev || nh_dev == dev)
1520					deleted++;
1521				if (nh_dev == dev)
1522					nh_del = true;
1523			} endfor_nexthops(rt);
1524
1525			/* if there are no more nexthops, delete the route */
1526			if (deleted == rt->rt_nhn) {
1527				mpls_route_update(net, index, NULL, NULL);
1528				continue;
1529			}
1530
1531			if (nh_del) {
				/* Copy the whole route, including every nexthop
				 * slot of rt_nh_size bytes.  Assumes this matches
				 * how the route was allocated.
				 * NOTE(review): confirm against the allocator.
				 */
1532				size_t size = sizeof(*rt) + rt->rt_nhn *
1533					rt->rt_nh_size;
1534				struct mpls_route *orig = rt;
1535
1536				rt = kmalloc(size, GFP_KERNEL);
1537				if (!rt)
1538					return -ENOMEM;
1539				memcpy(rt, orig, size);
1540			}
1541		}
1542
		/* Flag nexthops on @dev and recount live ones.  When nh_del is
		 * set, rt is the private copy made above; otherwise no nexthop
		 * of an UNREGISTER route uses @dev and only the count changes.
		 */
1543		change_nexthops(rt) {
1544			unsigned int nh_flags = nh->nh_flags;
1545
1546			if (rtnl_dereference(nh->nh_dev) != dev)
1547				goto next;
1548
1549			switch (event) {
1550			case NETDEV_DOWN:
1551			case NETDEV_UNREGISTER:
1552				nh_flags |= RTNH_F_DEAD;
1553				fallthrough;
1554			case NETDEV_CHANGE:
1555				nh_flags |= RTNH_F_LINKDOWN;
1556				break;
1557			}
			/* Forget the departing device. */
1558			if (event == NETDEV_UNREGISTER)
1559				RCU_INIT_POINTER(nh->nh_dev, NULL);
1560
			/* Store only on change.  WRITE_ONCE presumably pairs
			 * with lockless readers of nh_flags.
			 */
1561			if (nh->nh_flags != nh_flags)
1562				WRITE_ONCE(nh->nh_flags, nh_flags);
1563next:
			/* Alive means neither DEAD nor LINKDOWN. */
1564			if (!(nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN)))
1565				alive++;
1566		} endfor_nexthops(rt);
1567
1568		WRITE_ONCE(rt->rt_nhn_alive, alive);
1569
		/* Publish the edited copy in place of the original route. */
1570		if (nh_del)
1571			mpls_route_update(net, index, rt, NULL);
1572	}
1573
1574	return 0;
1575}
1576
1577static void mpls_ifup(struct net_device *dev, unsigned int flags)
1578{
1579	struct mpls_route __rcu **platform_label;
1580	struct net *net = dev_net(dev);
1581	unsigned index;
1582	u8 alive;
1583
1584	platform_label = rtnl_dereference(net->mpls.platform_label);
1585	for (index = 0; index < net->mpls.platform_labels; index++) {
1586		struct mpls_route *rt = rtnl_dereference(platform_label[index]);
1587
1588		if (!rt)
1589			continue;
1590
1591		alive = 0;
1592		change_nexthops(rt) {
1593			unsigned int nh_flags = nh->nh_flags;
1594			struct net_device *nh_dev =
1595				rtnl_dereference(nh->nh_dev);
1596
1597			if (!(nh_flags & flags)) {
1598				alive++;
1599				continue;
1600			}
1601			if (nh_dev != dev)
1602				continue;
1603			alive++;
1604			nh_flags &= ~flags;
1605			WRITE_ONCE(nh->nh_flags, nh_flags);
1606		} endfor_nexthops(rt);
1607
1608		WRITE_ONCE(rt->rt_nhn_alive, alive);
1609	}
1610}
1611
1612static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
1613			   void *ptr)
1614{
1615	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
1616	struct mpls_dev *mdev;
1617	unsigned int flags;
1618
1619	if (event == NETDEV_REGISTER) {
1620		mdev = mpls_add_dev(dev);
1621		if (IS_ERR(mdev))
1622			return notifier_from_errno(PTR_ERR(mdev));
1623
1624		return NOTIFY_OK;
1625	}
1626
1627	mdev = mpls_dev_get(dev);
1628	if (!mdev)
1629		return NOTIFY_OK;
1630
1631	switch (event) {
1632		int err;
1633
1634	case NETDEV_DOWN:
1635		err = mpls_ifdown(dev, event);
1636		if (err)
1637			return notifier_from_errno(err);
1638		break;
1639	case NETDEV_UP:
1640		flags = dev_get_flags(dev);
1641		if (flags & (IFF_RUNNING | IFF_LOWER_UP))
1642			mpls_ifup(dev, RTNH_F_DEAD | RTNH_F_LINKDOWN);
1643		else
1644			mpls_ifup(dev, RTNH_F_DEAD);
1645		break;
1646	case NETDEV_CHANGE:
1647		flags = dev_get_flags(dev);
1648		if (flags & (IFF_RUNNING | IFF_LOWER_UP)) {
1649			mpls_ifup(dev, RTNH_F_DEAD | RTNH_F_LINKDOWN);
1650		} else {
1651			err = mpls_ifdown(dev, event);
1652			if (err)
1653				return notifier_from_errno(err);
1654		}
1655		break;
1656	case NETDEV_UNREGISTER:
1657		err = mpls_ifdown(dev, event);
1658		if (err)
1659			return notifier_from_errno(err);
1660		mdev = mpls_dev_get(dev);
1661		if (mdev) {
1662			mpls_dev_sysctl_unregister(dev, mdev);
1663			RCU_INIT_POINTER(dev->mpls_ptr, NULL);
1664			call_rcu(&mdev->rcu, mpls_dev_destroy_rcu);
1665		}
1666		break;
1667	case NETDEV_CHANGENAME:
1668		mdev = mpls_dev_get(dev);
1669		if (mdev) {
1670			mpls_dev_sysctl_unregister(dev, mdev);
1671			err = mpls_dev_sysctl_register(dev, mdev);
1672			if (err)
1673				return notifier_from_errno(err);
1674		}
1675		break;
1676	}
1677	return NOTIFY_OK;
1678}
1679
1680static struct notifier_block mpls_dev_notifier = {
1681	.notifier_call = mpls_dev_notify,
1682};
1683
1684static int nla_put_via(struct sk_buff *skb,
1685		       u8 table, const void *addr, int alen)
1686{
1687	static const int table_to_family[NEIGH_NR_TABLES + 1] = {
1688		AF_INET, AF_INET6, AF_DECnet, AF_PACKET,
1689	};
1690	struct nlattr *nla;
1691	struct rtvia *via;
1692	int family = AF_UNSPEC;
1693
1694	nla = nla_reserve(skb, RTA_VIA, alen + 2);
1695	if (!nla)
1696		return -EMSGSIZE;
1697
1698	if (table <= NEIGH_NR_TABLES)
1699		family = table_to_family[table];
1700
1701	via = nla_data(nla);
1702	via->rtvia_family = family;
1703	memcpy(via->rtvia_addr, addr, alen);
1704	return 0;
1705}
1706
1707int nla_put_labels(struct sk_buff *skb, int attrtype,
1708		   u8 labels, const u32 label[])
1709{
1710	struct nlattr *nla;
1711	struct mpls_shim_hdr *nla_label;
1712	bool bos;
1713	int i;
1714	nla = nla_reserve(skb, attrtype, labels*4);
1715	if (!nla)
1716		return -EMSGSIZE;
1717
1718	nla_label = nla_data(nla);
1719	bos = true;
1720	for (i = labels - 1; i >= 0; i--) {
1721		nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos);
1722		bos = false;
1723	}
1724
1725	return 0;
1726}
1727EXPORT_SYMBOL_GPL(nla_put_labels);
1728
1729int nla_get_labels(const struct nlattr *nla, u8 max_labels, u8 *labels,
1730		   u32 label[], struct netlink_ext_ack *extack)
1731{
1732	unsigned len = nla_len(nla);
1733	struct mpls_shim_hdr *nla_label;
1734	u8 nla_labels;
1735	bool bos;
1736	int i;
1737
1738	/* len needs to be an even multiple of 4 (the label size). Number
1739	 * of labels is a u8 so check for overflow.
1740	 */
1741	if (len & 3 || len / 4 > 255) {
1742		NL_SET_ERR_MSG_ATTR(extack, nla,
1743				    "Invalid length for labels attribute");
1744		return -EINVAL;
1745	}
1746
1747	/* Limit the number of new labels allowed */
1748	nla_labels = len/4;
1749	if (nla_labels > max_labels) {
1750		NL_SET_ERR_MSG(extack, "Too many labels");
1751		return -EINVAL;
1752	}
1753
1754	/* when label == NULL, caller wants number of labels */
1755	if (!label)
1756		goto out;
1757
1758	nla_label = nla_data(nla);
1759	bos = true;
1760	for (i = nla_labels - 1; i >= 0; i--, bos = false) {
1761		struct mpls_entry_decoded dec;
1762		dec = mpls_entry_decode(nla_label + i);
1763
1764		/* Ensure the bottom of stack flag is properly set
1765		 * and ttl and tc are both clear.
1766		 */
1767		if (dec.ttl) {
1768			NL_SET_ERR_MSG_ATTR(extack, nla,
1769					    "TTL in label must be 0");
1770			return -EINVAL;
1771		}
1772
1773		if (dec.tc) {
1774			NL_SET_ERR_MSG_ATTR(extack, nla,
1775					    "Traffic class in label must be 0");
1776			return -EINVAL;
1777		}
1778
1779		if (dec.bos != bos) {
1780			NL_SET_BAD_ATTR(extack, nla);
1781			if (bos) {
1782				NL_SET_ERR_MSG(extack,
1783					       "BOS bit must be set in first label");
1784			} else {
1785				NL_SET_ERR_MSG(extack,
1786					       "BOS bit can only be set in first label");
1787			}
1788			return -EINVAL;
1789		}
1790
1791		switch (dec.label) {
1792		case MPLS_LABEL_IMPLNULL:
1793			/* RFC3032: This is a label that an LSR may
1794			 * assign and distribute, but which never
1795			 * actually appears in the encapsulation.
1796			 */
1797			NL_SET_ERR_MSG_ATTR(extack, nla,
1798					    "Implicit NULL Label (3) can not be used in encapsulation");
1799			return -EINVAL;
1800		}
1801
1802		label[i] = dec.label;
1803	}
1804out:
1805	*labels = nla_labels;
1806	return 0;
1807}
1808EXPORT_SYMBOL_GPL(nla_get_labels);
1809
/*
 * rtm_to_route_config - translate an RTM_NEWROUTE/RTM_DELROUTE request into
 * an mpls_route_config.
 * @skb:    request skb (source of the sender's portid and netns)
 * @nlh:    request header
 * @cfg:    output.  Both visible callers pass a kzalloc()'ed struct; only
 *          the defaults set here and the attributes present in the request
 *          are written.
 * @extack: extended ack for error reporting
 *
 * Header requirements:
 *   - AF_MPLS family and rtm_dst_len 20;
 *   - no source length, TOS or flags;
 *   - main table, universe scope and unicast type.
 * rtm_protocol is copied unchanged.
 *
 * Accepted attributes: RTA_OIF, RTA_NEWDST, RTA_DST, RTA_VIA, RTA_MULTIPATH
 * and RTA_TTL_PROPAGATE.  RTA_GATEWAY and any other attribute are rejected.
 * RTA_MULTIPATH is not parsed here; cfg->rc_mp points into the request
 * message.
 *
 * Return: 0 on success, or the error from the initial attribute parse.
 * Every later validation failure (including helper failures, whose own
 * codes are discarded) returns -EINVAL.
 */
1810static int rtm_to_route_config(struct sk_buff *skb,
1811			       struct nlmsghdr *nlh,
1812			       struct mpls_route_config *cfg,
1813			       struct netlink_ext_ack *extack)
1814{
1815	struct rtmsg *rtm;
1816	struct nlattr *tb[RTA_MAX+1];
1817	int index;
1818	int err;
1819
1820	err = nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX,
1821				     rtm_mpls_policy, extack);
1822	if (err < 0)
1823		goto errout;
1824
	/* From here on, every failure path reports -EINVAL. */
1825	err = -EINVAL;
1826	rtm = nlmsg_data(nlh);
1827
1828	if (rtm->rtm_family != AF_MPLS) {
1829		NL_SET_ERR_MSG(extack, "Invalid address family in rtmsg");
1830		goto errout;
1831	}
1832	if (rtm->rtm_dst_len != 20) {
1833		NL_SET_ERR_MSG(extack, "rtm_dst_len must be 20 for MPLS");
1834		goto errout;
1835	}
1836	if (rtm->rtm_src_len != 0) {
1837		NL_SET_ERR_MSG(extack, "rtm_src_len must be 0 for MPLS");
1838		goto errout;
1839	}
1840	if (rtm->rtm_tos != 0) {
1841		NL_SET_ERR_MSG(extack, "rtm_tos must be 0 for MPLS");
1842		goto errout;
1843	}
1844	if (rtm->rtm_table != RT_TABLE_MAIN) {
1845		NL_SET_ERR_MSG(extack,
1846			       "MPLS only supports the main route table");
1847		goto errout;
1848	}
1849	/* Any value is acceptable for rtm_protocol */
1850
1851	/* As mpls uses destination specific addresses
1852	 * (or source specific address in the case of multicast)
1853	 * all addresses have universal scope.
1854	 */
1855	if (rtm->rtm_scope != RT_SCOPE_UNIVERSE) {
1856		NL_SET_ERR_MSG(extack,
1857			       "Invalid route scope  - MPLS only supports UNIVERSE");
1858		goto errout;
1859	}
1860	if (rtm->rtm_type != RTN_UNICAST) {
1861		NL_SET_ERR_MSG(extack,
1862			       "Invalid route type - MPLS only supports UNICAST");
1863		goto errout;
1864	}
1865	if (rtm->rtm_flags != 0) {
1866		NL_SET_ERR_MSG(extack, "rtm_flags must be 0 for MPLS");
1867		goto errout;
1868	}
1869
	/* Defaults; attributes below may override them.  rc_nlinfo.nl_net
	 * must be set before the loop because RTA_DST validation uses it.
	 */
1870	cfg->rc_label		= LABEL_NOT_SPECIFIED;
1871	cfg->rc_protocol	= rtm->rtm_protocol;
1872	cfg->rc_via_table	= MPLS_NEIGH_TABLE_UNSPEC;
1873	cfg->rc_ttl_propagate	= MPLS_TTL_PROP_DEFAULT;
1874	cfg->rc_nlflags		= nlh->nlmsg_flags;
1875	cfg->rc_nlinfo.portid	= NETLINK_CB(skb).portid;
1876	cfg->rc_nlinfo.nlh	= nlh;
1877	cfg->rc_nlinfo.nl_net	= sock_net(skb->sk);
1878
	/* Attributes are processed in attribute-type order; the first
	 * failure aborts the request.
	 */
1879	for (index = 0; index <= RTA_MAX; index++) {
1880		struct nlattr *nla = tb[index];
1881		if (!nla)
1882			continue;
1883
1884		switch (index) {
1885		case RTA_OIF:
1886			cfg->rc_ifindex = nla_get_u32(nla);
1887			break;
1888		case RTA_NEWDST:
1889			if (nla_get_labels(nla, MAX_NEW_LABELS,
1890					   &cfg->rc_output_labels,
1891					   cfg->rc_output_label, extack))
1892				goto errout;
1893			break;
1894		case RTA_DST:
1895		{
			/* Exactly one incoming label, which must also be a
			 * valid label for this namespace.
			 */
1896			u8 label_count;
1897			if (nla_get_labels(nla, 1, &label_count,
1898					   &cfg->rc_label, extack))
1899				goto errout;
1900
1901			if (!mpls_label_ok(cfg->rc_nlinfo.nl_net,
1902					   &cfg->rc_label, extack))
1903				goto errout;
1904			break;
1905		}
1906		case RTA_GATEWAY:
1907			NL_SET_ERR_MSG(extack, "MPLS does not support RTA_GATEWAY attribute");
1908			goto errout;
1909		case RTA_VIA:
1910		{
1911			if (nla_get_via(nla, &cfg->rc_via_alen,
1912					&cfg->rc_via_table, cfg->rc_via,
1913					extack))
1914				goto errout;
1915			break;
1916		}
1917		case RTA_MULTIPATH:
1918		{
			/* Stored as a pointer into the request; parsed later. */
1919			cfg->rc_mp = nla_data(nla);
1920			cfg->rc_mp_len = nla_len(nla);
1921			break;
1922		}
1923		case RTA_TTL_PROPAGATE:
1924		{
1925			u8 ttl_propagate = nla_get_u8(nla);
1926
1927			if (ttl_propagate > 1) {
1928				NL_SET_ERR_MSG_ATTR(extack, nla,
1929						    "RTA_TTL_PROPAGATE can only be 0 or 1");
1930				goto errout;
1931			}
1932			cfg->rc_ttl_propagate = ttl_propagate ?
1933				MPLS_TTL_PROP_ENABLED :
1934				MPLS_TTL_PROP_DISABLED;
1935			break;
1936		}
1937		default:
1938			NL_SET_ERR_MSG_ATTR(extack, nla, "Unknown attribute");
1939			/* Unsupported attribute */
1940			goto errout;
1941		}
1942	}
1943
1944	err = 0;
1945errout:
1946	return err;
1947}
1948
1949static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
1950			     struct netlink_ext_ack *extack)
1951{
1952	struct mpls_route_config *cfg;
1953	int err;
1954
1955	cfg = kzalloc(sizeof(*cfg), GFP_KERNEL);
1956	if (!cfg)
1957		return -ENOMEM;
1958
1959	err = rtm_to_route_config(skb, nlh, cfg, extack);
1960	if (err < 0)
1961		goto out;
1962
1963	err = mpls_route_del(cfg, extack);
1964out:
1965	kfree(cfg);
1966
1967	return err;
1968}
1969
1970
1971static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
1972			     struct netlink_ext_ack *extack)
1973{
1974	struct mpls_route_config *cfg;
1975	int err;
1976
1977	cfg = kzalloc(sizeof(*cfg), GFP_KERNEL);
1978	if (!cfg)
1979		return -ENOMEM;
1980
1981	err = rtm_to_route_config(skb, nlh, cfg, extack);
1982	if (err < 0)
1983		goto out;
1984
1985	err = mpls_route_add(cfg, extack);
1986out:
1987	kfree(cfg);
1988
1989	return err;
1990}
1991
/*
 * mpls_dump_route - append one route message for @label -> @rt to @skb.
 * @skb:    message buffer
 * @portid: netlink port id for the header
 * @seq:    sequence number for the header
 * @event:  message type (e.g. RTM_NEWROUTE)
 * @label:  incoming label, emitted as RTA_DST
 * @rt:     the route
 * @flags:  nlmsg header flags (e.g. NLM_F_MULTI)
 *
 * Single-nexthop routes use top-level RTA_NEWDST, RTA_VIA and RTA_OIF
 * attributes.  Multipath routes use one rtnexthop (plus its RTA_NEWDST and
 * RTA_VIA) per nexthop inside RTA_MULTIPATH; nexthops without a device are
 * skipped.  RTNH_F_LINKDOWN / RTNH_F_DEAD are reflected in rtm_flags.  For
 * multipath, this happens only when the number of emitted nexthops with
 * the flag equals rt_nhn.
 *
 * Uses rtnl_dereference(), so the caller is expected to hold RTNL.
 *
 * Return: 0, or -EMSGSIZE when @skb is too small.  Any partially written
 * message is cancelled in that case.
 */
1992static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
1993			   u32 label, struct mpls_route *rt, int flags)
1994{
1995	struct net_device *dev;
1996	struct nlmsghdr *nlh;
1997	struct rtmsg *rtm;
1998
1999	nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags);
2000	if (nlh == NULL)
2001		return -EMSGSIZE;
2002
2003	rtm = nlmsg_data(nlh);
2004	rtm->rtm_family = AF_MPLS;
2005	rtm->rtm_dst_len = 20;
2006	rtm->rtm_src_len = 0;
2007	rtm->rtm_tos = 0;
2008	rtm->rtm_table = RT_TABLE_MAIN;
2009	rtm->rtm_protocol = rt->rt_protocol;
2010	rtm->rtm_scope = RT_SCOPE_UNIVERSE;
2011	rtm->rtm_type = RTN_UNICAST;
2012	rtm->rtm_flags = 0;
2013
2014	if (nla_put_labels(skb, RTA_DST, 1, &label))
2015		goto nla_put_failure;
2016
	/* RTA_TTL_PROPAGATE is only emitted when the route overrides the
	 * default.
	 */
2017	if (rt->rt_ttl_propagate != MPLS_TTL_PROP_DEFAULT) {
2018		bool ttl_propagate =
2019			rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED;
2020
2021		if (nla_put_u8(skb, RTA_TTL_PROPAGATE,
2022			       ttl_propagate))
2023			goto nla_put_failure;
2024	}
2025	if (rt->rt_nhn == 1) {
2026		const struct mpls_nh *nh = rt->rt_nh;
2027
2028		if (nh->nh_labels &&
2029		    nla_put_labels(skb, RTA_NEWDST, nh->nh_labels,
2030				   nh->nh_label))
2031			goto nla_put_failure;
2032		if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
2033		    nla_put_via(skb, nh->nh_via_table, mpls_nh_via(rt, nh),
2034				nh->nh_via_alen))
2035			goto nla_put_failure;
2036		dev = rtnl_dereference(nh->nh_dev);
2037		if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
2038			goto nla_put_failure;
2039		if (nh->nh_flags & RTNH_F_LINKDOWN)
2040			rtm->rtm_flags |= RTNH_F_LINKDOWN;
2041		if (nh->nh_flags & RTNH_F_DEAD)
2042			rtm->rtm_flags |= RTNH_F_DEAD;
2043	} else {
2044		struct rtnexthop *rtnh;
2045		struct nlattr *mp;
2046		u8 linkdown = 0;
2047		u8 dead = 0;
2048
2049		mp = nla_nest_start_noflag(skb, RTA_MULTIPATH);
2050		if (!mp)
2051			goto nla_put_failure;
2052
2053		for_nexthops(rt) {
2054			dev = rtnl_dereference(nh->nh_dev);
2055			if (!dev)
2056				continue;
2057
			/* rtnh_flags and rtnh_hops are never assigned here,
			 * only OR-ed into.  This assumes nla_reserve_nohdr()
			 * hands back zeroed space.  NOTE(review): confirm.
			 */
2058			rtnh = nla_reserve_nohdr(skb, sizeof(*rtnh));
2059			if (!rtnh)
2060				goto nla_put_failure;
2061
2062			rtnh->rtnh_ifindex = dev->ifindex;
2063			if (nh->nh_flags & RTNH_F_LINKDOWN) {
2064				rtnh->rtnh_flags |= RTNH_F_LINKDOWN;
2065				linkdown++;
2066			}
2067			if (nh->nh_flags & RTNH_F_DEAD) {
2068				rtnh->rtnh_flags |= RTNH_F_DEAD;
2069				dead++;
2070			}
2071
2072			if (nh->nh_labels && nla_put_labels(skb, RTA_NEWDST,
2073							    nh->nh_labels,
2074							    nh->nh_label))
2075				goto nla_put_failure;
2076			if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
2077			    nla_put_via(skb, nh->nh_via_table,
2078					mpls_nh_via(rt, nh),
2079					nh->nh_via_alen))
2080				goto nla_put_failure;
2081
2082			/* length of rtnetlink header + attributes */
2083			rtnh->rtnh_len = nlmsg_get_pos(skb) - (void *)rtnh;
2084		} endfor_nexthops(rt);
2085
		/* NOTE(review): nexthops skipped above for lacking a device
		 * are not counted.  So once any nexthop has lost its device,
		 * these whole-route flags can no longer be set, even if every
		 * remaining nexthop is down.  Confirm this is intended.
		 */
2086		if (linkdown == rt->rt_nhn)
2087			rtm->rtm_flags |= RTNH_F_LINKDOWN;
2088		if (dead == rt->rt_nhn)
2089			rtm->rtm_flags |= RTNH_F_DEAD;
2090
2091		nla_nest_end(skb, mp);
2092	}
2093
2094	nlmsg_end(skb, nlh);
2095	return 0;
2096
2097nla_put_failure:
2098	nlmsg_cancel(skb, nlh);
2099	return -EMSGSIZE;
2100}
2101
2102#if IS_ENABLED(CONFIG_INET)
2103static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh,
2104				   struct fib_dump_filter *filter,
2105				   struct netlink_callback *cb)
2106{
2107	return ip_valid_fib_dump_req(net, nlh, filter, cb);
2108}
2109#else
2110static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh,
2111				   struct fib_dump_filter *filter,
2112				   struct netlink_callback *cb)
2113{
2114	struct netlink_ext_ack *extack = cb->extack;
2115	struct nlattr *tb[RTA_MAX + 1];
2116	struct rtmsg *rtm;
2117	int err, i;
2118
2119	if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) {
2120		NL_SET_ERR_MSG_MOD(extack, "Invalid header for FIB dump request");
2121		return -EINVAL;
2122	}
2123
2124	rtm = nlmsg_data(nlh);
2125	if (rtm->rtm_dst_len || rtm->rtm_src_len  || rtm->rtm_tos   ||
2126	    rtm->rtm_table   || rtm->rtm_scope    || rtm->rtm_type  ||
2127	    rtm->rtm_flags) {
2128		NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for FIB dump request");
2129		return -EINVAL;
2130	}
2131
2132	if (rtm->rtm_protocol) {
2133		filter->protocol = rtm->rtm_protocol;
2134		filter->filter_set = 1;
2135		cb->answer_flags = NLM_F_DUMP_FILTERED;
2136	}
2137
2138	err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX,
2139					    rtm_mpls_policy, extack);
2140	if (err < 0)
2141		return err;
2142
2143	for (i = 0; i <= RTA_MAX; ++i) {
2144		int ifindex;
2145
2146		if (i == RTA_OIF) {
2147			ifindex = nla_get_u32(tb[i]);
2148			filter->dev = __dev_get_by_index(net, ifindex);
2149			if (!filter->dev)
2150				return -ENODEV;
2151			filter->filter_set = 1;
2152		} else if (tb[i]) {
2153			NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in dump request");
2154			return -EINVAL;
2155		}
2156	}
2157
2158	return 0;
2159}
2160#endif
2161
2162static bool mpls_rt_uses_dev(struct mpls_route *rt,
2163			     const struct net_device *dev)
2164{
2165	struct net_device *nh_dev;
2166
2167	if (rt->rt_nhn == 1) {
2168		struct mpls_nh *nh = rt->rt_nh;
2169
2170		nh_dev = rtnl_dereference(nh->nh_dev);
2171		if (dev == nh_dev)
2172			return true;
2173	} else {
2174		for_nexthops(rt) {
2175			nh_dev = rtnl_dereference(nh->nh_dev);
2176			if (nh_dev == dev)
2177				return true;
2178		} endfor_nexthops(rt);
2179	}
2180
2181	return false;
2182}
2183
2184static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
2185{
2186	const struct nlmsghdr *nlh = cb->nlh;
2187	struct net *net = sock_net(skb->sk);
2188	struct mpls_route __rcu **platform_label;
2189	struct fib_dump_filter filter = {};
2190	unsigned int flags = NLM_F_MULTI;
2191	size_t platform_labels;
2192	unsigned int index;
2193
2194	ASSERT_RTNL();
2195
2196	if (cb->strict_check) {
2197		int err;
2198
2199		err = mpls_valid_fib_dump_req(net, nlh, &filter, cb);
2200		if (err < 0)
2201			return err;
2202
2203		/* for MPLS, there is only 1 table with fixed type and flags.
2204		 * If either are set in the filter then return nothing.
2205		 */
2206		if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) ||
2207		    (filter.rt_type && filter.rt_type != RTN_UNICAST) ||
2208		     filter.flags)
2209			return skb->len;
2210	}
2211
2212	index = cb->args[0];
2213	if (index < MPLS_LABEL_FIRST_UNRESERVED)
2214		index = MPLS_LABEL_FIRST_UNRESERVED;
2215
2216	platform_label = rtnl_dereference(net->mpls.platform_label);
2217	platform_labels = net->mpls.platform_labels;
2218
2219	if (filter.filter_set)
2220		flags |= NLM_F_DUMP_FILTERED;
2221
2222	for (; index < platform_labels; index++) {
2223		struct mpls_route *rt;
2224
2225		rt = rtnl_dereference(platform_label[index]);
2226		if (!rt)
2227			continue;
2228
2229		if ((filter.dev && !mpls_rt_uses_dev(rt, filter.dev)) ||
2230		    (filter.protocol && rt->rt_protocol != filter.protocol))
2231			continue;
2232
2233		if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid,
2234				    cb->nlh->nlmsg_seq, RTM_NEWROUTE,
2235				    index, rt, flags) < 0)
2236			break;
2237	}
2238	cb->args[0] = index;
2239
2240	return skb->len;
2241}
2242
2243static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
2244{
2245	size_t payload =
2246		NLMSG_ALIGN(sizeof(struct rtmsg))
2247		+ nla_total_size(4)			/* RTA_DST */
2248		+ nla_total_size(1);			/* RTA_TTL_PROPAGATE */
2249
2250	if (rt->rt_nhn == 1) {
2251		struct mpls_nh *nh = rt->rt_nh;
2252
2253		if (nh->nh_dev)
2254			payload += nla_total_size(4); /* RTA_OIF */
2255		if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC) /* RTA_VIA */
2256			payload += nla_total_size(2 + nh->nh_via_alen);
2257		if (nh->nh_labels) /* RTA_NEWDST */
2258			payload += nla_total_size(nh->nh_labels * 4);
2259	} else {
2260		/* each nexthop is packed in an attribute */
2261		size_t nhsize = 0;
2262
2263		for_nexthops(rt) {
2264			if (!rtnl_dereference(nh->nh_dev))
2265				continue;
2266			nhsize += nla_total_size(sizeof(struct rtnexthop));
2267			/* RTA_VIA */
2268			if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC)
2269				nhsize += nla_total_size(2 + nh->nh_via_alen);
2270			if (nh->nh_labels)
2271				nhsize += nla_total_size(nh->nh_labels * 4);
2272		} endfor_nexthops(rt);
2273		/* nested attribute */
2274		payload += nla_total_size(nhsize);
2275	}
2276
2277	return payload;
2278}
2279
2280static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
2281		       struct nlmsghdr *nlh, struct net *net, u32 portid,
2282		       unsigned int nlm_flags)
2283{
2284	struct sk_buff *skb;
2285	u32 seq = nlh ? nlh->nlmsg_seq : 0;
2286	int err = -ENOBUFS;
2287
2288	skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
2289	if (skb == NULL)
2290		goto errout;
2291
2292	err = mpls_dump_route(skb, portid, seq, event, label, rt, nlm_flags);
2293	if (err < 0) {
2294		/* -EMSGSIZE implies BUG in lfib_nlmsg_size */
2295		WARN_ON(err == -EMSGSIZE);
2296		kfree_skb(skb);
2297		goto errout;
2298	}
2299	rtnl_notify(skb, net, portid, RTNLGRP_MPLS_ROUTE, nlh, GFP_KERNEL);
2300
2301	return;
2302errout:
2303	if (err < 0)
2304		rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err);
2305}
2306
2307static int mpls_valid_getroute_req(struct sk_buff *skb,
2308				   const struct nlmsghdr *nlh,
2309				   struct nlattr **tb,
2310				   struct netlink_ext_ack *extack)
2311{
2312	struct rtmsg *rtm;
2313	int i, err;
2314
2315	if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) {
2316		NL_SET_ERR_MSG_MOD(extack,
2317				   "Invalid header for get route request");
2318		return -EINVAL;
2319	}
2320
2321	if (!netlink_strict_get_check(skb))
2322		return nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX,
2323					      rtm_mpls_policy, extack);
2324
2325	rtm = nlmsg_data(nlh);
2326	if ((rtm->rtm_dst_len && rtm->rtm_dst_len != 20) ||
2327	    rtm->rtm_src_len || rtm->rtm_tos || rtm->rtm_table ||
2328	    rtm->rtm_protocol || rtm->rtm_scope || rtm->rtm_type) {
2329		NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get route request");
2330		return -EINVAL;
2331	}
2332	if (rtm->rtm_flags & ~RTM_F_FIB_MATCH) {
2333		NL_SET_ERR_MSG_MOD(extack,
2334				   "Invalid flags for get route request");
2335		return -EINVAL;
2336	}
2337
2338	err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX,
2339					    rtm_mpls_policy, extack);
2340	if (err)
2341		return err;
2342
2343	if ((tb[RTA_DST] || tb[RTA_NEWDST]) && !rtm->rtm_dst_len) {
2344		NL_SET_ERR_MSG_MOD(extack, "rtm_dst_len must be 20 for MPLS");
2345		return -EINVAL;
2346	}
2347
2348	for (i = 0; i <= RTA_MAX; i++) {
2349		if (!tb[i])
2350			continue;
2351
2352		switch (i) {
2353		case RTA_DST:
2354		case RTA_NEWDST:
2355			break;
2356		default:
2357			NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get route request");
2358			return -EINVAL;
2359		}
2360	}
2361
2362	return 0;
2363}
2364
2365static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
2366			 struct netlink_ext_ack *extack)
2367{
2368	struct net *net = sock_net(in_skb->sk);
2369	u32 portid = NETLINK_CB(in_skb).portid;
2370	u32 in_label = LABEL_NOT_SPECIFIED;
2371	struct nlattr *tb[RTA_MAX + 1];
2372	u32 labels[MAX_NEW_LABELS];
2373	struct mpls_shim_hdr *hdr;
2374	unsigned int hdr_size = 0;
2375	struct net_device *dev;
2376	struct mpls_route *rt;
2377	struct rtmsg *rtm, *r;
2378	struct nlmsghdr *nlh;
2379	struct sk_buff *skb;
2380	struct mpls_nh *nh;
2381	u8 n_labels;
2382	int err;
2383
2384	err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack);
2385	if (err < 0)
2386		goto errout;
2387
2388	rtm = nlmsg_data(in_nlh);
2389
2390	if (tb[RTA_DST]) {
2391		u8 label_count;
2392
2393		if (nla_get_labels(tb[RTA_DST], 1, &label_count,
2394				   &in_label, extack)) {
2395			err = -EINVAL;
2396			goto errout;
2397		}
2398
2399		if (!mpls_label_ok(net, &in_label, extack)) {
2400			err = -EINVAL;
2401			goto errout;
2402		}
2403	}
2404
2405	rt = mpls_route_input_rcu(net, in_label);
2406	if (!rt) {
2407		err = -ENETUNREACH;
2408		goto errout;
2409	}
2410
2411	if (rtm->rtm_flags & RTM_F_FIB_MATCH) {
2412		skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
2413		if (!skb) {
2414			err = -ENOBUFS;
2415			goto errout;
2416		}
2417
2418		err = mpls_dump_route(skb, portid, in_nlh->nlmsg_seq,
2419				      RTM_NEWROUTE, in_label, rt, 0);
2420		if (err < 0) {
2421			/* -EMSGSIZE implies BUG in lfib_nlmsg_size */
2422			WARN_ON(err == -EMSGSIZE);
2423			goto errout_free;
2424		}
2425
2426		return rtnl_unicast(skb, net, portid);
2427	}
2428
2429	if (tb[RTA_NEWDST]) {
2430		if (nla_get_labels(tb[RTA_NEWDST], MAX_NEW_LABELS, &n_labels,
2431				   labels, extack) != 0) {
2432			err = -EINVAL;
2433			goto errout;
2434		}
2435
2436		hdr_size = n_labels * sizeof(struct mpls_shim_hdr);
2437	}
2438
2439	skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL);
2440	if (!skb) {
2441		err = -ENOBUFS;
2442		goto errout;
2443	}
2444
2445	skb->protocol = htons(ETH_P_MPLS_UC);
2446
2447	if (hdr_size) {
2448		bool bos;
2449		int i;
2450
2451		if (skb_cow(skb, hdr_size)) {
2452			err = -ENOBUFS;
2453			goto errout_free;
2454		}
2455
2456		skb_reserve(skb, hdr_size);
2457		skb_push(skb, hdr_size);
2458		skb_reset_network_header(skb);
2459
2460		/* Push new labels */
2461		hdr = mpls_hdr(skb);
2462		bos = true;
2463		for (i = n_labels - 1; i >= 0; i--) {
2464			hdr[i] = mpls_entry_encode(labels[i],
2465						   1, 0, bos);
2466			bos = false;
2467		}
2468	}
2469
2470	nh = mpls_select_multipath(rt, skb);
2471	if (!nh) {
2472		err = -ENETUNREACH;
2473		goto errout_free;
2474	}
2475
2476	if (hdr_size) {
2477		skb_pull(skb, hdr_size);
2478		skb_reset_network_header(skb);
2479	}
2480
2481	nlh = nlmsg_put(skb, portid, in_nlh->nlmsg_seq,
2482			RTM_NEWROUTE, sizeof(*r), 0);
2483	if (!nlh) {
2484		err = -EMSGSIZE;
2485		goto errout_free;
2486	}
2487
2488	r = nlmsg_data(nlh);
2489	r->rtm_family	 = AF_MPLS;
2490	r->rtm_dst_len	= 20;
2491	r->rtm_src_len	= 0;
2492	r->rtm_table	= RT_TABLE_MAIN;
2493	r->rtm_type	= RTN_UNICAST;
2494	r->rtm_scope	= RT_SCOPE_UNIVERSE;
2495	r->rtm_protocol = rt->rt_protocol;
2496	r->rtm_flags	= 0;
2497
2498	if (nla_put_labels(skb, RTA_DST, 1, &in_label))
2499		goto nla_put_failure;
2500
2501	if (nh->nh_labels &&
2502	    nla_put_labels(skb, RTA_NEWDST, nh->nh_labels,
2503			   nh->nh_label))
2504		goto nla_put_failure;
2505
2506	if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
2507	    nla_put_via(skb, nh->nh_via_table, mpls_nh_via(rt, nh),
2508			nh->nh_via_alen))
2509		goto nla_put_failure;
2510	dev = rtnl_dereference(nh->nh_dev);
2511	if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
2512		goto nla_put_failure;
2513
2514	nlmsg_end(skb, nlh);
2515
2516	err = rtnl_unicast(skb, net, portid);
2517errout:
2518	return err;
2519
2520nla_put_failure:
2521	nlmsg_cancel(skb, nlh);
2522	err = -EMSGSIZE;
2523errout_free:
2524	kfree_skb(skb);
2525	return err;
2526}
2527
2528static int resize_platform_label_table(struct net *net, size_t limit)
2529{
2530	size_t size = sizeof(struct mpls_route *) * limit;
2531	size_t old_limit;
2532	size_t cp_size;
2533	struct mpls_route __rcu **labels = NULL, **old;
2534	struct mpls_route *rt0 = NULL, *rt2 = NULL;
2535	unsigned index;
2536
2537	if (size) {
2538		labels = kvzalloc(size, GFP_KERNEL);
2539		if (!labels)
2540			goto nolabels;
2541	}
2542
2543	/* In case the predefined labels need to be populated */
2544	if (limit > MPLS_LABEL_IPV4NULL) {
2545		struct net_device *lo = net->loopback_dev;
2546		rt0 = mpls_rt_alloc(1, lo->addr_len, 0);
2547		if (IS_ERR(rt0))
2548			goto nort0;
2549		RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo);
2550		rt0->rt_protocol = RTPROT_KERNEL;
2551		rt0->rt_payload_type = MPT_IPV4;
2552		rt0->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT;
2553		rt0->rt_nh->nh_via_table = NEIGH_LINK_TABLE;
2554		rt0->rt_nh->nh_via_alen = lo->addr_len;
2555		memcpy(__mpls_nh_via(rt0, rt0->rt_nh), lo->dev_addr,
2556		       lo->addr_len);
2557	}
2558	if (limit > MPLS_LABEL_IPV6NULL) {
2559		struct net_device *lo = net->loopback_dev;
2560		rt2 = mpls_rt_alloc(1, lo->addr_len, 0);
2561		if (IS_ERR(rt2))
2562			goto nort2;
2563		RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo);
2564		rt2->rt_protocol = RTPROT_KERNEL;
2565		rt2->rt_payload_type = MPT_IPV6;
2566		rt2->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT;
2567		rt2->rt_nh->nh_via_table = NEIGH_LINK_TABLE;
2568		rt2->rt_nh->nh_via_alen = lo->addr_len;
2569		memcpy(__mpls_nh_via(rt2, rt2->rt_nh), lo->dev_addr,
2570		       lo->addr_len);
2571	}
2572
2573	rtnl_lock();
2574	/* Remember the original table */
2575	old = rtnl_dereference(net->mpls.platform_label);
2576	old_limit = net->mpls.platform_labels;
2577
2578	/* Free any labels beyond the new table */
2579	for (index = limit; index < old_limit; index++)
2580		mpls_route_update(net, index, NULL, NULL);
2581
2582	/* Copy over the old labels */
2583	cp_size = size;
2584	if (old_limit < limit)
2585		cp_size = old_limit * sizeof(struct mpls_route *);
2586
2587	memcpy(labels, old, cp_size);
2588
2589	/* If needed set the predefined labels */
2590	if ((old_limit <= MPLS_LABEL_IPV6NULL) &&
2591	    (limit > MPLS_LABEL_IPV6NULL)) {
2592		RCU_INIT_POINTER(labels[MPLS_LABEL_IPV6NULL], rt2);
2593		rt2 = NULL;
2594	}
2595
2596	if ((old_limit <= MPLS_LABEL_IPV4NULL) &&
2597	    (limit > MPLS_LABEL_IPV4NULL)) {
2598		RCU_INIT_POINTER(labels[MPLS_LABEL_IPV4NULL], rt0);
2599		rt0 = NULL;
2600	}
2601
2602	/* Update the global pointers */
2603	net->mpls.platform_labels = limit;
2604	rcu_assign_pointer(net->mpls.platform_label, labels);
2605
2606	rtnl_unlock();
2607
2608	mpls_rt_free(rt2);
2609	mpls_rt_free(rt0);
2610
2611	if (old) {
2612		synchronize_rcu();
2613		kvfree(old);
2614	}
2615	return 0;
2616
2617nort2:
2618	mpls_rt_free(rt0);
2619nort0:
2620	kvfree(labels);
2621nolabels:
2622	return -ENOMEM;
2623}
2624
2625static int mpls_platform_labels(struct ctl_table *table, int write,
2626				void *buffer, size_t *lenp, loff_t *ppos)
2627{
2628	struct net *net = table->data;
2629	int platform_labels = net->mpls.platform_labels;
2630	int ret;
2631	struct ctl_table tmp = {
2632		.procname	= table->procname,
2633		.data		= &platform_labels,
2634		.maxlen		= sizeof(int),
2635		.mode		= table->mode,
2636		.extra1		= SYSCTL_ZERO,
2637		.extra2		= &label_limit,
2638	};
2639
2640	ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
2641
2642	if (write && ret == 0)
2643		ret = resize_platform_label_table(net, platform_labels);
2644
2645	return ret;
2646}
2647
/*
 * Byte offset of @field within struct net, stored in the pointer-typed
 * ctl_table .data member.  mpls_net_init() later adds the namespace base
 * address to turn it into a real pointer.  offsetof() gives the same value
 * as the "&((struct net *)0)->field" idiom without performing a member
 * access through a null pointer.
 */
#define MPLS_NS_SYSCTL_OFFSET(field)		\
	((void *)offsetof(struct net, field))
2650
2651static const struct ctl_table mpls_table[] = {
2652	{
2653		.procname	= "platform_labels",
2654		.data		= NULL,
2655		.maxlen		= sizeof(int),
2656		.mode		= 0644,
2657		.proc_handler	= mpls_platform_labels,
2658	},
2659	{
2660		.procname	= "ip_ttl_propagate",
2661		.data		= MPLS_NS_SYSCTL_OFFSET(mpls.ip_ttl_propagate),
2662		.maxlen		= sizeof(int),
2663		.mode		= 0644,
2664		.proc_handler	= proc_dointvec_minmax,
2665		.extra1		= SYSCTL_ZERO,
2666		.extra2		= SYSCTL_ONE,
2667	},
2668	{
2669		.procname	= "default_ttl",
2670		.data		= MPLS_NS_SYSCTL_OFFSET(mpls.default_ttl),
2671		.maxlen		= sizeof(int),
2672		.mode		= 0644,
2673		.proc_handler	= proc_dointvec_minmax,
2674		.extra1		= SYSCTL_ONE,
2675		.extra2		= &ttl_max,
2676	},
2677	{ }
2678};
2679
2680static int mpls_net_init(struct net *net)
2681{
2682	struct ctl_table *table;
2683	int i;
2684
2685	net->mpls.platform_labels = 0;
2686	net->mpls.platform_label = NULL;
2687	net->mpls.ip_ttl_propagate = 1;
2688	net->mpls.default_ttl = 255;
2689
2690	table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
2691	if (table == NULL)
2692		return -ENOMEM;
2693
2694	/* Table data contains only offsets relative to the base of
2695	 * the mdev at this point, so make them absolute.
2696	 */
2697	for (i = 0; i < ARRAY_SIZE(mpls_table) - 1; i++)
2698		table[i].data = (char *)net + (uintptr_t)table[i].data;
2699
2700	net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
2701	if (net->mpls.ctl == NULL) {
2702		kfree(table);
2703		return -ENOMEM;
2704	}
2705
2706	return 0;
2707}
2708
2709static void mpls_net_exit(struct net *net)
2710{
2711	struct mpls_route __rcu **platform_label;
2712	size_t platform_labels;
2713	struct ctl_table *table;
2714	unsigned int index;
2715
2716	table = net->mpls.ctl->ctl_table_arg;
2717	unregister_net_sysctl_table(net->mpls.ctl);
2718	kfree(table);
2719
2720	/* An rcu grace period has passed since there was a device in
2721	 * the network namespace (and thus the last in flight packet)
2722	 * left this network namespace.  This is because
2723	 * unregister_netdevice_many and netdev_run_todo has completed
2724	 * for each network device that was in this network namespace.
2725	 *
2726	 * As such no additional rcu synchronization is necessary when
2727	 * freeing the platform_label table.
2728	 */
2729	rtnl_lock();
2730	platform_label = rtnl_dereference(net->mpls.platform_label);
2731	platform_labels = net->mpls.platform_labels;
2732	for (index = 0; index < platform_labels; index++) {
2733		struct mpls_route *rt = rtnl_dereference(platform_label[index]);
2734		RCU_INIT_POINTER(platform_label[index], NULL);
2735		mpls_notify_route(net, index, rt, NULL, NULL);
2736		mpls_rt_free(rt);
2737	}
2738	rtnl_unlock();
2739
2740	kvfree(platform_label);
2741}
2742
2743static struct pernet_operations mpls_net_ops = {
2744	.init = mpls_net_init,
2745	.exit = mpls_net_exit,
2746};
2747
2748static struct rtnl_af_ops mpls_af_ops __read_mostly = {
2749	.family		   = AF_MPLS,
2750	.fill_stats_af	   = mpls_fill_stats_af,
2751	.get_stats_af_size = mpls_get_stats_af_size,
2752};
2753
2754static int __init mpls_init(void)
2755{
2756	int err;
2757
2758	BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
2759
2760	err = register_pernet_subsys(&mpls_net_ops);
2761	if (err)
2762		goto out;
2763
2764	err = register_netdevice_notifier(&mpls_dev_notifier);
2765	if (err)
2766		goto out_unregister_pernet;
2767
2768	dev_add_pack(&mpls_packet_type);
2769
2770	rtnl_af_register(&mpls_af_ops);
2771
2772	rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_NEWROUTE,
2773			     mpls_rtm_newroute, NULL, 0);
2774	rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_DELROUTE,
2775			     mpls_rtm_delroute, NULL, 0);
2776	rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_GETROUTE,
2777			     mpls_getroute, mpls_dump_routes, 0);
2778	rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_GETNETCONF,
2779			     mpls_netconf_get_devconf,
2780			     mpls_netconf_dump_devconf, 0);
2781	err = ipgre_tunnel_encap_add_mpls_ops();
2782	if (err)
2783		pr_err("Can't add mpls over gre tunnel ops\n");
2784
2785	err = 0;
2786out:
2787	return err;
2788
2789out_unregister_pernet:
2790	unregister_pernet_subsys(&mpls_net_ops);
2791	goto out;
2792}
2793module_init(mpls_init);
2794
/*
 * Module teardown: undo the registrations made by mpls_init().  The
 * rtnetlink handlers, af_ops, packet handler and netdevice notifier are
 * removed before unregister_pernet_subsys(), which uses mpls_net_ops and
 * hence mpls_net_exit().  Presumably this keeps new requests and packets
 * from reaching per-namespace state while it is freed; confirm before
 * reordering.
 */
static void __exit mpls_exit(void)
{
	rtnl_unregister_all(PF_MPLS);
	rtnl_af_unregister(&mpls_af_ops);
	dev_remove_pack(&mpls_packet_type);
	unregister_netdevice_notifier(&mpls_dev_notifier);
	unregister_pernet_subsys(&mpls_net_ops);
	/* Registered last in mpls_init() but removed last here as well. */
	ipgre_tunnel_encap_del_mpls_ops();
}
module_exit(mpls_exit);
2805
/* Module metadata; MODULE_ALIAS_NETPROTO adds a PF_MPLS protocol alias. */
MODULE_DESCRIPTION("MultiProtocol Label Switching");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS_NETPROTO(PF_MPLS);
2809