162306a36Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0
262306a36Sopenharmony_ci/* Copyright (c) 2019 Facebook */
362306a36Sopenharmony_ci
/* WARNING: This implementation is not necessarily the same
 * as tcp_dctcp.c.  Its purpose is mainly to test
 * the kernel BPF logic.
 */
862306a36Sopenharmony_ci
962306a36Sopenharmony_ci#include <stddef.h>
1062306a36Sopenharmony_ci#include <linux/bpf.h>
1162306a36Sopenharmony_ci#include <linux/types.h>
1262306a36Sopenharmony_ci#include <linux/stddef.h>
1362306a36Sopenharmony_ci#include <linux/tcp.h>
1462306a36Sopenharmony_ci#include <errno.h>
1562306a36Sopenharmony_ci#include <bpf/bpf_helpers.h>
1662306a36Sopenharmony_ci#include <bpf/bpf_tracing.h>
1762306a36Sopenharmony_ci#include "bpf_tcp_helpers.h"
1862306a36Sopenharmony_ci
1962306a36Sopenharmony_cichar _license[] SEC("license") = "GPL";
2062306a36Sopenharmony_ci
2162306a36Sopenharmony_civolatile const char fallback[TCP_CA_NAME_MAX];
2262306a36Sopenharmony_ciconst char bpf_dctcp[] = "bpf_dctcp";
2362306a36Sopenharmony_ciconst char tcp_cdg[] = "cdg";
2462306a36Sopenharmony_cichar cc_res[TCP_CA_NAME_MAX];
2562306a36Sopenharmony_ciint tcp_cdg_res = 0;
2662306a36Sopenharmony_ciint stg_result = 0;
2762306a36Sopenharmony_ciint ebusy_cnt = 0;
2862306a36Sopenharmony_ci
2962306a36Sopenharmony_cistruct {
3062306a36Sopenharmony_ci	__uint(type, BPF_MAP_TYPE_SK_STORAGE);
3162306a36Sopenharmony_ci	__uint(map_flags, BPF_F_NO_PREALLOC);
3262306a36Sopenharmony_ci	__type(key, int);
3362306a36Sopenharmony_ci	__type(value, int);
3462306a36Sopenharmony_ci} sk_stg_map SEC(".maps");
3562306a36Sopenharmony_ci
3662306a36Sopenharmony_ci#define DCTCP_MAX_ALPHA	1024U
3762306a36Sopenharmony_ci
3862306a36Sopenharmony_cistruct dctcp {
3962306a36Sopenharmony_ci	__u32 old_delivered;
4062306a36Sopenharmony_ci	__u32 old_delivered_ce;
4162306a36Sopenharmony_ci	__u32 prior_rcv_nxt;
4262306a36Sopenharmony_ci	__u32 dctcp_alpha;
4362306a36Sopenharmony_ci	__u32 next_seq;
4462306a36Sopenharmony_ci	__u32 ce_state;
4562306a36Sopenharmony_ci	__u32 loss_cwnd;
4662306a36Sopenharmony_ci};
4762306a36Sopenharmony_ci
4862306a36Sopenharmony_cistatic unsigned int dctcp_shift_g = 4; /* g = 1/2^4 */
4962306a36Sopenharmony_cistatic unsigned int dctcp_alpha_on_init = DCTCP_MAX_ALPHA;
5062306a36Sopenharmony_ci
5162306a36Sopenharmony_cistatic __always_inline void dctcp_reset(const struct tcp_sock *tp,
5262306a36Sopenharmony_ci					struct dctcp *ca)
5362306a36Sopenharmony_ci{
5462306a36Sopenharmony_ci	ca->next_seq = tp->snd_nxt;
5562306a36Sopenharmony_ci
5662306a36Sopenharmony_ci	ca->old_delivered = tp->delivered;
5762306a36Sopenharmony_ci	ca->old_delivered_ce = tp->delivered_ce;
5862306a36Sopenharmony_ci}
5962306a36Sopenharmony_ci
6062306a36Sopenharmony_ciSEC("struct_ops/dctcp_init")
6162306a36Sopenharmony_civoid BPF_PROG(dctcp_init, struct sock *sk)
6262306a36Sopenharmony_ci{
6362306a36Sopenharmony_ci	const struct tcp_sock *tp = tcp_sk(sk);
6462306a36Sopenharmony_ci	struct dctcp *ca = inet_csk_ca(sk);
6562306a36Sopenharmony_ci	int *stg;
6662306a36Sopenharmony_ci
6762306a36Sopenharmony_ci	if (!(tp->ecn_flags & TCP_ECN_OK) && fallback[0]) {
6862306a36Sopenharmony_ci		/* Switch to fallback */
6962306a36Sopenharmony_ci		if (bpf_setsockopt(sk, SOL_TCP, TCP_CONGESTION,
7062306a36Sopenharmony_ci				   (void *)fallback, sizeof(fallback)) == -EBUSY)
7162306a36Sopenharmony_ci			ebusy_cnt++;
7262306a36Sopenharmony_ci
7362306a36Sopenharmony_ci		/* Switch back to myself and the recurred dctcp_init()
7462306a36Sopenharmony_ci		 * will get -EBUSY for all bpf_setsockopt(TCP_CONGESTION),
7562306a36Sopenharmony_ci		 * except the last "cdg" one.
7662306a36Sopenharmony_ci		 */
7762306a36Sopenharmony_ci		if (bpf_setsockopt(sk, SOL_TCP, TCP_CONGESTION,
7862306a36Sopenharmony_ci				   (void *)bpf_dctcp, sizeof(bpf_dctcp)) == -EBUSY)
7962306a36Sopenharmony_ci			ebusy_cnt++;
8062306a36Sopenharmony_ci
8162306a36Sopenharmony_ci		/* Switch back to fallback */
8262306a36Sopenharmony_ci		if (bpf_setsockopt(sk, SOL_TCP, TCP_CONGESTION,
8362306a36Sopenharmony_ci				   (void *)fallback, sizeof(fallback)) == -EBUSY)
8462306a36Sopenharmony_ci			ebusy_cnt++;
8562306a36Sopenharmony_ci
8662306a36Sopenharmony_ci		/* Expecting -ENOTSUPP for tcp_cdg_res */
8762306a36Sopenharmony_ci		tcp_cdg_res = bpf_setsockopt(sk, SOL_TCP, TCP_CONGESTION,
8862306a36Sopenharmony_ci					     (void *)tcp_cdg, sizeof(tcp_cdg));
8962306a36Sopenharmony_ci		bpf_getsockopt(sk, SOL_TCP, TCP_CONGESTION,
9062306a36Sopenharmony_ci			       (void *)cc_res, sizeof(cc_res));
9162306a36Sopenharmony_ci		return;
9262306a36Sopenharmony_ci	}
9362306a36Sopenharmony_ci
9462306a36Sopenharmony_ci	ca->prior_rcv_nxt = tp->rcv_nxt;
9562306a36Sopenharmony_ci	ca->dctcp_alpha = min(dctcp_alpha_on_init, DCTCP_MAX_ALPHA);
9662306a36Sopenharmony_ci	ca->loss_cwnd = 0;
9762306a36Sopenharmony_ci	ca->ce_state = 0;
9862306a36Sopenharmony_ci
9962306a36Sopenharmony_ci	stg = bpf_sk_storage_get(&sk_stg_map, (void *)tp, NULL, 0);
10062306a36Sopenharmony_ci	if (stg) {
10162306a36Sopenharmony_ci		stg_result = *stg;
10262306a36Sopenharmony_ci		bpf_sk_storage_delete(&sk_stg_map, (void *)tp);
10362306a36Sopenharmony_ci	}
10462306a36Sopenharmony_ci	dctcp_reset(tp, ca);
10562306a36Sopenharmony_ci}
10662306a36Sopenharmony_ci
10762306a36Sopenharmony_ciSEC("struct_ops/dctcp_ssthresh")
10862306a36Sopenharmony_ci__u32 BPF_PROG(dctcp_ssthresh, struct sock *sk)
10962306a36Sopenharmony_ci{
11062306a36Sopenharmony_ci	struct dctcp *ca = inet_csk_ca(sk);
11162306a36Sopenharmony_ci	struct tcp_sock *tp = tcp_sk(sk);
11262306a36Sopenharmony_ci
11362306a36Sopenharmony_ci	ca->loss_cwnd = tp->snd_cwnd;
11462306a36Sopenharmony_ci	return max(tp->snd_cwnd - ((tp->snd_cwnd * ca->dctcp_alpha) >> 11U), 2U);
11562306a36Sopenharmony_ci}
11662306a36Sopenharmony_ci
11762306a36Sopenharmony_ciSEC("struct_ops/dctcp_update_alpha")
11862306a36Sopenharmony_civoid BPF_PROG(dctcp_update_alpha, struct sock *sk, __u32 flags)
11962306a36Sopenharmony_ci{
12062306a36Sopenharmony_ci	const struct tcp_sock *tp = tcp_sk(sk);
12162306a36Sopenharmony_ci	struct dctcp *ca = inet_csk_ca(sk);
12262306a36Sopenharmony_ci
12362306a36Sopenharmony_ci	/* Expired RTT */
12462306a36Sopenharmony_ci	if (!before(tp->snd_una, ca->next_seq)) {
12562306a36Sopenharmony_ci		__u32 delivered_ce = tp->delivered_ce - ca->old_delivered_ce;
12662306a36Sopenharmony_ci		__u32 alpha = ca->dctcp_alpha;
12762306a36Sopenharmony_ci
12862306a36Sopenharmony_ci		/* alpha = (1 - g) * alpha + g * F */
12962306a36Sopenharmony_ci
13062306a36Sopenharmony_ci		alpha -= min_not_zero(alpha, alpha >> dctcp_shift_g);
13162306a36Sopenharmony_ci		if (delivered_ce) {
13262306a36Sopenharmony_ci			__u32 delivered = tp->delivered - ca->old_delivered;
13362306a36Sopenharmony_ci
13462306a36Sopenharmony_ci			/* If dctcp_shift_g == 1, a 32bit value would overflow
13562306a36Sopenharmony_ci			 * after 8 M packets.
13662306a36Sopenharmony_ci			 */
13762306a36Sopenharmony_ci			delivered_ce <<= (10 - dctcp_shift_g);
13862306a36Sopenharmony_ci			delivered_ce /= max(1U, delivered);
13962306a36Sopenharmony_ci
14062306a36Sopenharmony_ci			alpha = min(alpha + delivered_ce, DCTCP_MAX_ALPHA);
14162306a36Sopenharmony_ci		}
14262306a36Sopenharmony_ci		ca->dctcp_alpha = alpha;
14362306a36Sopenharmony_ci		dctcp_reset(tp, ca);
14462306a36Sopenharmony_ci	}
14562306a36Sopenharmony_ci}
14662306a36Sopenharmony_ci
14762306a36Sopenharmony_cistatic __always_inline void dctcp_react_to_loss(struct sock *sk)
14862306a36Sopenharmony_ci{
14962306a36Sopenharmony_ci	struct dctcp *ca = inet_csk_ca(sk);
15062306a36Sopenharmony_ci	struct tcp_sock *tp = tcp_sk(sk);
15162306a36Sopenharmony_ci
15262306a36Sopenharmony_ci	ca->loss_cwnd = tp->snd_cwnd;
15362306a36Sopenharmony_ci	tp->snd_ssthresh = max(tp->snd_cwnd >> 1U, 2U);
15462306a36Sopenharmony_ci}
15562306a36Sopenharmony_ci
15662306a36Sopenharmony_ciSEC("struct_ops/dctcp_state")
15762306a36Sopenharmony_civoid BPF_PROG(dctcp_state, struct sock *sk, __u8 new_state)
15862306a36Sopenharmony_ci{
15962306a36Sopenharmony_ci	if (new_state == TCP_CA_Recovery &&
16062306a36Sopenharmony_ci	    new_state != BPF_CORE_READ_BITFIELD(inet_csk(sk), icsk_ca_state))
16162306a36Sopenharmony_ci		dctcp_react_to_loss(sk);
16262306a36Sopenharmony_ci	/* We handle RTO in dctcp_cwnd_event to ensure that we perform only
16362306a36Sopenharmony_ci	 * one loss-adjustment per RTT.
16462306a36Sopenharmony_ci	 */
16562306a36Sopenharmony_ci}
16662306a36Sopenharmony_ci
16762306a36Sopenharmony_cistatic __always_inline void dctcp_ece_ack_cwr(struct sock *sk, __u32 ce_state)
16862306a36Sopenharmony_ci{
16962306a36Sopenharmony_ci	struct tcp_sock *tp = tcp_sk(sk);
17062306a36Sopenharmony_ci
17162306a36Sopenharmony_ci	if (ce_state == 1)
17262306a36Sopenharmony_ci		tp->ecn_flags |= TCP_ECN_DEMAND_CWR;
17362306a36Sopenharmony_ci	else
17462306a36Sopenharmony_ci		tp->ecn_flags &= ~TCP_ECN_DEMAND_CWR;
17562306a36Sopenharmony_ci}
17662306a36Sopenharmony_ci
17762306a36Sopenharmony_ci/* Minimal DCTP CE state machine:
17862306a36Sopenharmony_ci *
17962306a36Sopenharmony_ci * S:	0 <- last pkt was non-CE
18062306a36Sopenharmony_ci *	1 <- last pkt was CE
18162306a36Sopenharmony_ci */
18262306a36Sopenharmony_cistatic __always_inline
18362306a36Sopenharmony_civoid dctcp_ece_ack_update(struct sock *sk, enum tcp_ca_event evt,
18462306a36Sopenharmony_ci			  __u32 *prior_rcv_nxt, __u32 *ce_state)
18562306a36Sopenharmony_ci{
18662306a36Sopenharmony_ci	__u32 new_ce_state = (evt == CA_EVENT_ECN_IS_CE) ? 1 : 0;
18762306a36Sopenharmony_ci
18862306a36Sopenharmony_ci	if (*ce_state != new_ce_state) {
18962306a36Sopenharmony_ci		/* CE state has changed, force an immediate ACK to
19062306a36Sopenharmony_ci		 * reflect the new CE state. If an ACK was delayed,
19162306a36Sopenharmony_ci		 * send that first to reflect the prior CE state.
19262306a36Sopenharmony_ci		 */
19362306a36Sopenharmony_ci		if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) {
19462306a36Sopenharmony_ci			dctcp_ece_ack_cwr(sk, *ce_state);
19562306a36Sopenharmony_ci			bpf_tcp_send_ack(sk, *prior_rcv_nxt);
19662306a36Sopenharmony_ci		}
19762306a36Sopenharmony_ci		inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW;
19862306a36Sopenharmony_ci	}
19962306a36Sopenharmony_ci	*prior_rcv_nxt = tcp_sk(sk)->rcv_nxt;
20062306a36Sopenharmony_ci	*ce_state = new_ce_state;
20162306a36Sopenharmony_ci	dctcp_ece_ack_cwr(sk, new_ce_state);
20262306a36Sopenharmony_ci}
20362306a36Sopenharmony_ci
20462306a36Sopenharmony_ciSEC("struct_ops/dctcp_cwnd_event")
20562306a36Sopenharmony_civoid BPF_PROG(dctcp_cwnd_event, struct sock *sk, enum tcp_ca_event ev)
20662306a36Sopenharmony_ci{
20762306a36Sopenharmony_ci	struct dctcp *ca = inet_csk_ca(sk);
20862306a36Sopenharmony_ci
20962306a36Sopenharmony_ci	switch (ev) {
21062306a36Sopenharmony_ci	case CA_EVENT_ECN_IS_CE:
21162306a36Sopenharmony_ci	case CA_EVENT_ECN_NO_CE:
21262306a36Sopenharmony_ci		dctcp_ece_ack_update(sk, ev, &ca->prior_rcv_nxt, &ca->ce_state);
21362306a36Sopenharmony_ci		break;
21462306a36Sopenharmony_ci	case CA_EVENT_LOSS:
21562306a36Sopenharmony_ci		dctcp_react_to_loss(sk);
21662306a36Sopenharmony_ci		break;
21762306a36Sopenharmony_ci	default:
21862306a36Sopenharmony_ci		/* Don't care for the rest. */
21962306a36Sopenharmony_ci		break;
22062306a36Sopenharmony_ci	}
22162306a36Sopenharmony_ci}
22262306a36Sopenharmony_ci
22362306a36Sopenharmony_ciSEC("struct_ops/dctcp_cwnd_undo")
22462306a36Sopenharmony_ci__u32 BPF_PROG(dctcp_cwnd_undo, struct sock *sk)
22562306a36Sopenharmony_ci{
22662306a36Sopenharmony_ci	const struct dctcp *ca = inet_csk_ca(sk);
22762306a36Sopenharmony_ci
22862306a36Sopenharmony_ci	return max(tcp_sk(sk)->snd_cwnd, ca->loss_cwnd);
22962306a36Sopenharmony_ci}
23062306a36Sopenharmony_ci
23162306a36Sopenharmony_ciextern void tcp_reno_cong_avoid(struct sock *sk, __u32 ack, __u32 acked) __ksym;
23262306a36Sopenharmony_ci
23362306a36Sopenharmony_ciSEC("struct_ops/dctcp_reno_cong_avoid")
23462306a36Sopenharmony_civoid BPF_PROG(dctcp_cong_avoid, struct sock *sk, __u32 ack, __u32 acked)
23562306a36Sopenharmony_ci{
23662306a36Sopenharmony_ci	tcp_reno_cong_avoid(sk, ack, acked);
23762306a36Sopenharmony_ci}
23862306a36Sopenharmony_ci
23962306a36Sopenharmony_ciSEC(".struct_ops")
24062306a36Sopenharmony_cistruct tcp_congestion_ops dctcp_nouse = {
24162306a36Sopenharmony_ci	.init		= (void *)dctcp_init,
24262306a36Sopenharmony_ci	.set_state	= (void *)dctcp_state,
24362306a36Sopenharmony_ci	.flags		= TCP_CONG_NEEDS_ECN,
24462306a36Sopenharmony_ci	.name		= "bpf_dctcp_nouse",
24562306a36Sopenharmony_ci};
24662306a36Sopenharmony_ci
24762306a36Sopenharmony_ciSEC(".struct_ops")
24862306a36Sopenharmony_cistruct tcp_congestion_ops dctcp = {
24962306a36Sopenharmony_ci	.init		= (void *)dctcp_init,
25062306a36Sopenharmony_ci	.in_ack_event   = (void *)dctcp_update_alpha,
25162306a36Sopenharmony_ci	.cwnd_event	= (void *)dctcp_cwnd_event,
25262306a36Sopenharmony_ci	.ssthresh	= (void *)dctcp_ssthresh,
25362306a36Sopenharmony_ci	.cong_avoid	= (void *)dctcp_cong_avoid,
25462306a36Sopenharmony_ci	.undo_cwnd	= (void *)dctcp_cwnd_undo,
25562306a36Sopenharmony_ci	.set_state	= (void *)dctcp_state,
25662306a36Sopenharmony_ci	.flags		= TCP_CONG_NEEDS_ECN,
25762306a36Sopenharmony_ci	.name		= "bpf_dctcp",
25862306a36Sopenharmony_ci};
259