1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2010-2013 Felix Fietkau <nbd@openwrt.org>
4 * Copyright (C) 2019-2020 Intel Corporation
5 */
6#include <linux/netdevice.h>
7#include <linux/types.h>
8#include <linux/skbuff.h>
9#include <linux/debugfs.h>
10#include <linux/random.h>
11#include <linux/moduleparam.h>
12#include <linux/ieee80211.h>
13#include <net/mac80211.h>
14#include "rate.h"
15#include "sta_info.h"
16#include "rc80211_minstrel.h"
17#include "rc80211_minstrel_ht.h"
18
/*
 * Reference A-MPDU length (in frames) and frame size (in bytes) used to
 * derive the per-rate durations stored in minstrel_mcs_groups[].
 */
#define AVG_AMPDU_SIZE	16
#define AVG_PKT_SIZE	1200

/*
 * Packets per statistics interval below which rate-switch sampling is
 * skipped (unless mp->sample_switch == 1).
 */
#define SAMPLE_SWITCH_THR	100

/* Number of bits for an average sized A-MPDU (AVG_AMPDU_SIZE x AVG_PKT_SIZE bytes) */
#define MCS_NBITS ((AVG_PKT_SIZE * AVG_AMPDU_SIZE) << 3)

/* Number of symbols (rounded up) for a packet with (bps) bits per symbol */
#define MCS_NSYMS(bps) DIV_ROUND_UP(MCS_NBITS, (bps))

/*
 * Transmission time (nanoseconds) for a packet containing (syms) symbols.
 * With short GI this is syms * 3.6 us plus a fixed 800 ns
 * ((syms * 18000 + 4000) / 5); otherwise syms * 4 us.
 * (sgi) is parenthesized so any expression can be passed safely.
 */
#define MCS_SYMBOL_TIME(sgi, syms)					\
	((sgi) ?							\
	  ((syms) * 18000 + 4000) / 5 :	/* syms * 3.6 us + 0.8 us */	\
	  ((syms) * 1000) << 2		/* syms * 4 us */		\
	)

/*
 * Transmit duration (ns) for the raw data part of an average sized packet:
 * the time of a whole average A-MPDU divided by AVG_AMPDU_SIZE.
 */
#define MCS_DURATION(streams, sgi, bps) \
	(MCS_SYMBOL_TIME(sgi, MCS_NSYMS((streams) * (bps))) / AVG_AMPDU_SIZE)
40
/* Channel bandwidth indices, stored in the .bw field of struct mcs_group */
#define BW_20			0
#define BW_40			1
#define BW_80			2

/*
 * Define group sort order: HT40 -> SGI -> #streams
 *
 * The whole expansion and every argument are parenthesized so the macro
 * stays correct inside larger expressions (e.g. "2 * GROUP_IDX(...)").
 */
#define GROUP_IDX(_streams, _sgi, _ht40)	\
	(MINSTREL_HT_GROUP_0 +			\
	 MINSTREL_MAX_STREAMS * 2 * (_ht40) +	\
	 MINSTREL_MAX_STREAMS * (_sgi) +	\
	 (_streams) - 1)

#define _MAX(a, b) (((a)>(b))?(a):(b))

/*
 * Right shift that makes (duration) fit in 16 bits: its bit length minus 16
 * (for a 32-bit unsigned int), or 0 if it already fits.
 * (duration) must be non-zero, as __builtin_clz(0) is undefined.
 */
#define GROUP_SHIFT(duration)						\
	_MAX(0, 16 - __builtin_clz(duration))
58
/*
 * MCS rate information for an MCS group: a designated initializer for entry
 * GROUP_IDX(_streams, _sgi, _ht40) of minstrel_mcs_groups[].
 * duration[] holds MCS_DURATION() of the 8 HT MCS indexes, using the
 * per-stream data bits per symbol (26..260 at 20 MHz, 54..540 at 40 MHz),
 * each value right-shifted by _s; readers shift it back left by .shift.
 */
#define __MCS_GROUP(_streams, _sgi, _ht40, _s)				\
	[GROUP_IDX(_streams, _sgi, _ht40)] = {				\
	.streams = _streams,						\
	.shift = _s,							\
	.bw = _ht40,							\
	.flags =							\
		IEEE80211_TX_RC_MCS |					\
		(_sgi ? IEEE80211_TX_RC_SHORT_GI : 0) |			\
		(_ht40 ? IEEE80211_TX_RC_40_MHZ_WIDTH : 0),		\
	.duration = {							\
		MCS_DURATION(_streams, _sgi, _ht40 ? 54 : 26) >> _s,	\
		MCS_DURATION(_streams, _sgi, _ht40 ? 108 : 52) >> _s,	\
		MCS_DURATION(_streams, _sgi, _ht40 ? 162 : 78) >> _s,	\
		MCS_DURATION(_streams, _sgi, _ht40 ? 216 : 104) >> _s,	\
		MCS_DURATION(_streams, _sgi, _ht40 ? 324 : 156) >> _s,	\
		MCS_DURATION(_streams, _sgi, _ht40 ? 432 : 208) >> _s,	\
		MCS_DURATION(_streams, _sgi, _ht40 ? 486 : 234) >> _s,	\
		MCS_DURATION(_streams, _sgi, _ht40 ? 540 : 260) >> _s	\
	}								\
}

/*
 * Duration shift of an HT group, derived from its slowest rate (MCS0, the
 * smallest bits-per-symbol value and therefore the largest duration).
 */
#define MCS_GROUP_SHIFT(_streams, _sgi, _ht40)				\
	GROUP_SHIFT(MCS_DURATION(_streams, _sgi, _ht40 ? 54 : 26))

/* HT group entry with the shift computed from the group's MCS0 duration */
#define MCS_GROUP(_streams, _sgi, _ht40)				\
	__MCS_GROUP(_streams, _sgi, _ht40,				\
		    MCS_GROUP_SHIFT(_streams, _sgi, _ht40))
87
/*
 * Index of a VHT group in minstrel_mcs_groups[]: starts at
 * MINSTREL_VHT_GROUP_0 and is ordered by bandwidth, then SGI, then #streams.
 */
#define VHT_GROUP_IDX(_streams, _sgi, _bw)				\
	(MINSTREL_VHT_GROUP_0 +						\
	 MINSTREL_MAX_STREAMS * 2 * (_bw) +				\
	 MINSTREL_MAX_STREAMS * (_sgi) +				\
	 (_streams) - 1)

/*
 * Select the value for bandwidth _bw: r3 for 80 MHz, r2 for 40 MHz and r1
 * otherwise (20 MHz). Used below for per-stream data bits per symbol.
 */
#define BW2VBPS(_bw, r3, r2, r1)					\
	(_bw == BW_80 ? r3 : _bw == BW_40 ? r2 : r1)

/*
 * VHT rate information for a VHT group: a designated initializer for entry
 * VHT_GROUP_IDX(_streams, _sgi, _bw) of minstrel_mcs_groups[].
 * duration[] holds MCS_DURATION() of VHT MCS 0-9, each right-shifted by _s;
 * readers shift the values back left by .shift.
 */
#define __VHT_GROUP(_streams, _sgi, _bw, _s)				\
	[VHT_GROUP_IDX(_streams, _sgi, _bw)] = {			\
	.streams = _streams,						\
	.shift = _s,							\
	.bw = _bw,							\
	.flags =							\
		IEEE80211_TX_RC_VHT_MCS |				\
		(_sgi ? IEEE80211_TX_RC_SHORT_GI : 0) |			\
		(_bw == BW_80 ? IEEE80211_TX_RC_80_MHZ_WIDTH :		\
		 _bw == BW_40 ? IEEE80211_TX_RC_40_MHZ_WIDTH : 0),	\
	.duration = {							\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw,  117,  54,  26)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw,  234, 108,  52)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw,  351, 162,  78)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw,  468, 216, 104)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw,  702, 324, 156)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw,  936, 432, 208)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw, 1053, 486, 234)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw, 1170, 540, 260)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw, 1404, 648, 312)) >> _s,	\
		MCS_DURATION(_streams, _sgi,				\
			     BW2VBPS(_bw, 1560, 720, 346)) >> _s	\
	}								\
}

/* Duration shift of a VHT group, derived from its slowest rate (MCS0) */
#define VHT_GROUP_SHIFT(_streams, _sgi, _bw)				\
	GROUP_SHIFT(MCS_DURATION(_streams, _sgi,			\
				 BW2VBPS(_bw,  117,  54,  26)))

/* VHT group entry with the shift computed from the group's MCS0 duration */
#define VHT_GROUP(_streams, _sgi, _bw)					\
	__VHT_GROUP(_streams, _sgi, _bw,				\
		    VHT_GROUP_SHIFT(_streams, _sgi, _bw))
138
/*
 * Duration (ns) of one CCK frame of _len bytes.
 * _bitrate is in units of 100 kbit/s (10/20/55/110 = 1/2/5.5/11 Mbit/s).
 * The sum is:
 *  - 10 us SIFS,
 *  - the PLCP preamble + header time (72 + 24 us short, 144 + 48 us long),
 *  - the air time of (_len + 4) bytes; the extra 4 bytes are presumably
 *    the FCS.
 */
#define CCK_DURATION(_bitrate, _short, _len)		\
	(1000 * (10 /* SIFS */ +			\
	 (_short ? 72 + 24 : 144 + 48) +		\
	 (8 * (_len + 4) * 10) / (_bitrate)))

/*
 * A data frame of AVG_PKT_SIZE bytes at _bitrate, plus a 60-byte frame
 * (presumably the ACK). The 60-byte frame uses the long preamble and is sent
 * at 2 Mbit/s, or at 1 Mbit/s when the data rate is 1 Mbit/s.
 */
#define CCK_ACK_DURATION(_bitrate, _short)			\
	(CCK_DURATION((_bitrate > 10 ? 20 : 10), false, 60) +	\
	 CCK_DURATION(_bitrate, _short, AVG_PKT_SIZE))

/* Durations for 1, 2, 5.5 and 11 Mbit/s, each right-shifted by _s */
#define CCK_DURATION_LIST(_short, _s)			\
	CCK_ACK_DURATION(10, _short) >> _s,		\
	CCK_ACK_DURATION(20, _short) >> _s,		\
	CCK_ACK_DURATION(55, _short) >> _s,		\
	CCK_ACK_DURATION(110, _short) >> _s

/*
 * minstrel_mcs_groups[] entry for the CCK group. duration[0..3] use the
 * long preamble and duration[4..7] the short preamble; minstrel_ht_get_stats()
 * relies on this layout when it adds 4 for short-preamble frames.
 */
#define __CCK_GROUP(_s)					\
	[MINSTREL_CCK_GROUP] = {			\
		.streams = 1,				\
		.flags = 0,				\
		.shift = _s,				\
		.duration = {				\
			CCK_DURATION_LIST(false, _s),	\
			CCK_DURATION_LIST(true, _s)	\
		}					\
	}

/* Shift derived from the slowest CCK rate (1 Mbit/s, long preamble) */
#define CCK_GROUP_SHIFT					\
	GROUP_SHIFT(CCK_ACK_DURATION(10, false))

#define CCK_GROUP __CCK_GROUP(CCK_GROUP_SHIFT)
169
170
171static bool minstrel_vht_only = true;
172module_param(minstrel_vht_only, bool, 0644);
173MODULE_PARM_DESC(minstrel_vht_only,
174		 "Use only VHT rates when VHT is supported by sta.");
175
/*
 * To enable sufficiently targeted rate sampling, MCS rates are divided into
 * groups, based on the number of streams and flags (HT40, SGI) that they
 * use.
 *
 * Sort order has to be fixed for the GROUP_IDX / VHT_GROUP_IDX macros to be
 * applicable:
 * BW -> SGI -> #streams
 *
 * The table holds the HT groups, the CCK group (at MINSTREL_CCK_GROUP) and
 * the VHT groups. Each entry uses a designated initializer, so an entry's
 * array position comes from its index macro; the listing order is for
 * readability.
 */
const struct mcs_group minstrel_mcs_groups[] = {
	/* HT, 20 MHz, long GI */
	MCS_GROUP(1, 0, BW_20),
	MCS_GROUP(2, 0, BW_20),
	MCS_GROUP(3, 0, BW_20),
	MCS_GROUP(4, 0, BW_20),

	/* HT, 20 MHz, short GI */
	MCS_GROUP(1, 1, BW_20),
	MCS_GROUP(2, 1, BW_20),
	MCS_GROUP(3, 1, BW_20),
	MCS_GROUP(4, 1, BW_20),

	/* HT, 40 MHz, long GI */
	MCS_GROUP(1, 0, BW_40),
	MCS_GROUP(2, 0, BW_40),
	MCS_GROUP(3, 0, BW_40),
	MCS_GROUP(4, 0, BW_40),

	/* HT, 40 MHz, short GI */
	MCS_GROUP(1, 1, BW_40),
	MCS_GROUP(2, 1, BW_40),
	MCS_GROUP(3, 1, BW_40),
	MCS_GROUP(4, 1, BW_40),

	/* legacy CCK rates (long and short preamble) */
	CCK_GROUP,

	/* VHT, 20 MHz, long GI */
	VHT_GROUP(1, 0, BW_20),
	VHT_GROUP(2, 0, BW_20),
	VHT_GROUP(3, 0, BW_20),
	VHT_GROUP(4, 0, BW_20),

	/* VHT, 20 MHz, short GI */
	VHT_GROUP(1, 1, BW_20),
	VHT_GROUP(2, 1, BW_20),
	VHT_GROUP(3, 1, BW_20),
	VHT_GROUP(4, 1, BW_20),

	/* VHT, 40 MHz, long GI */
	VHT_GROUP(1, 0, BW_40),
	VHT_GROUP(2, 0, BW_40),
	VHT_GROUP(3, 0, BW_40),
	VHT_GROUP(4, 0, BW_40),

	/* VHT, 40 MHz, short GI */
	VHT_GROUP(1, 1, BW_40),
	VHT_GROUP(2, 1, BW_40),
	VHT_GROUP(3, 1, BW_40),
	VHT_GROUP(4, 1, BW_40),

	/* VHT, 80 MHz, long GI */
	VHT_GROUP(1, 0, BW_80),
	VHT_GROUP(2, 0, BW_80),
	VHT_GROUP(3, 0, BW_80),
	VHT_GROUP(4, 0, BW_80),

	/* VHT, 80 MHz, short GI */
	VHT_GROUP(1, 1, BW_80),
	VHT_GROUP(2, 1, BW_80),
	VHT_GROUP(3, 1, BW_80),
	VHT_GROUP(4, 1, BW_80),
};
237
238static u8 sample_table[SAMPLE_COLUMNS][MCS_GROUP_RATES] __read_mostly;
239
240static void
241minstrel_ht_update_rates(struct minstrel_priv *mp, struct minstrel_ht_sta *mi);
242
243/*
244 * Some VHT MCSes are invalid (when Ndbps / Nes is not an integer)
245 * e.g for MCS9@20MHzx1Nss: Ndbps=8x52*(5/6) Nes=1
246 *
247 * Returns the valid mcs map for struct minstrel_mcs_group_data.supported
248 */
249static u16
250minstrel_get_valid_vht_rates(int bw, int nss, __le16 mcs_map)
251{
252	u16 mask = 0;
253
254	if (bw == BW_20) {
255		if (nss != 3 && nss != 6)
256			mask = BIT(9);
257	} else if (bw == BW_80) {
258		if (nss == 3 || nss == 7)
259			mask = BIT(6);
260		else if (nss == 6)
261			mask = BIT(9);
262	} else {
263		WARN_ON(bw != BW_40);
264	}
265
266	switch ((le16_to_cpu(mcs_map) >> (2 * (nss - 1))) & 3) {
267	case IEEE80211_VHT_MCS_SUPPORT_0_7:
268		mask |= 0x300;
269		break;
270	case IEEE80211_VHT_MCS_SUPPORT_0_8:
271		mask |= 0x200;
272		break;
273	case IEEE80211_VHT_MCS_SUPPORT_0_9:
274		break;
275	default:
276		mask = 0x3ff;
277	}
278
279	return 0x3ff & ~mask;
280}
281
282/*
283 * Look up an MCS group index based on mac80211 rate information
284 */
285static int
286minstrel_ht_get_group_idx(struct ieee80211_tx_rate *rate)
287{
288	return GROUP_IDX((rate->idx / 8) + 1,
289			 !!(rate->flags & IEEE80211_TX_RC_SHORT_GI),
290			 !!(rate->flags & IEEE80211_TX_RC_40_MHZ_WIDTH));
291}
292
293static int
294minstrel_vht_get_group_idx(struct ieee80211_tx_rate *rate)
295{
296	return VHT_GROUP_IDX(ieee80211_rate_get_vht_nss(rate),
297			     !!(rate->flags & IEEE80211_TX_RC_SHORT_GI),
298			     !!(rate->flags & IEEE80211_TX_RC_40_MHZ_WIDTH) +
299			     2*!!(rate->flags & IEEE80211_TX_RC_80_MHZ_WIDTH));
300}
301
302static struct minstrel_rate_stats *
303minstrel_ht_get_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
304		      struct ieee80211_tx_rate *rate)
305{
306	int group, idx;
307
308	if (rate->flags & IEEE80211_TX_RC_MCS) {
309		group = minstrel_ht_get_group_idx(rate);
310		idx = rate->idx % 8;
311	} else if (rate->flags & IEEE80211_TX_RC_VHT_MCS) {
312		group = minstrel_vht_get_group_idx(rate);
313		idx = ieee80211_rate_get_vht_mcs(rate);
314	} else {
315		group = MINSTREL_CCK_GROUP;
316
317		for (idx = 0; idx < ARRAY_SIZE(mp->cck_rates); idx++)
318			if (rate->idx == mp->cck_rates[idx])
319				break;
320
321		/* short preamble */
322		if ((mi->supported[group] & BIT(idx + 4)) &&
323		    (rate->flags & IEEE80211_TX_RC_USE_SHORT_PREAMBLE))
324			idx += 4;
325	}
326	return &mi->groups[group].rates[idx];
327}
328
329static inline struct minstrel_rate_stats *
330minstrel_get_ratestats(struct minstrel_ht_sta *mi, int index)
331{
332	return &mi->groups[index / MCS_GROUP_RATES].rates[index % MCS_GROUP_RATES];
333}
334
335static unsigned int
336minstrel_ht_avg_ampdu_len(struct minstrel_ht_sta *mi)
337{
338	if (!mi->avg_ampdu_len)
339		return AVG_AMPDU_SIZE;
340
341	return MINSTREL_TRUNC(mi->avg_ampdu_len);
342}
343
344/*
345 * Return current throughput based on the average A-MPDU length, taking into
346 * account the expected number of retransmissions and their expected length
347 */
348int
349minstrel_ht_get_tp_avg(struct minstrel_ht_sta *mi, int group, int rate,
350		       int prob_avg)
351{
352	unsigned int nsecs = 0;
353
354	/* do not account throughput if sucess prob is below 10% */
355	if (prob_avg < MINSTREL_FRAC(10, 100))
356		return 0;
357
358	if (group != MINSTREL_CCK_GROUP)
359		nsecs = 1000 * mi->overhead / minstrel_ht_avg_ampdu_len(mi);
360
361	nsecs += minstrel_mcs_groups[group].duration[rate] <<
362		 minstrel_mcs_groups[group].shift;
363
364	/*
365	 * For the throughput calculation, limit the probability value to 90% to
366	 * account for collision related packet error rate fluctuation
367	 * (prob is scaled - see MINSTREL_FRAC above)
368	 */
369	if (prob_avg > MINSTREL_FRAC(90, 100))
370		return MINSTREL_TRUNC(100000 * ((MINSTREL_FRAC(90, 100) * 1000)
371								      / nsecs));
372	else
373		return MINSTREL_TRUNC(100000 * ((prob_avg * 1000) / nsecs));
374}
375
376/*
377 * Find & sort topmost throughput rates
378 *
379 * If multiple rates provide equal throughput the sorting is based on their
380 * current success probability. Higher success probability is preferred among
381 * MCS groups, CCK rates do not provide aggregation and are therefore at last.
382 */
383static void
384minstrel_ht_sort_best_tp_rates(struct minstrel_ht_sta *mi, u16 index,
385			       u16 *tp_list)
386{
387	int cur_group, cur_idx, cur_tp_avg, cur_prob;
388	int tmp_group, tmp_idx, tmp_tp_avg, tmp_prob;
389	int j = MAX_THR_RATES;
390
391	cur_group = index / MCS_GROUP_RATES;
392	cur_idx = index  % MCS_GROUP_RATES;
393	cur_prob = mi->groups[cur_group].rates[cur_idx].prob_avg;
394	cur_tp_avg = minstrel_ht_get_tp_avg(mi, cur_group, cur_idx, cur_prob);
395
396	do {
397		tmp_group = tp_list[j - 1] / MCS_GROUP_RATES;
398		tmp_idx = tp_list[j - 1] % MCS_GROUP_RATES;
399		tmp_prob = mi->groups[tmp_group].rates[tmp_idx].prob_avg;
400		tmp_tp_avg = minstrel_ht_get_tp_avg(mi, tmp_group, tmp_idx,
401						    tmp_prob);
402		if (cur_tp_avg < tmp_tp_avg ||
403		    (cur_tp_avg == tmp_tp_avg && cur_prob <= tmp_prob))
404			break;
405		j--;
406	} while (j > 0);
407
408	if (j < MAX_THR_RATES - 1) {
409		memmove(&tp_list[j + 1], &tp_list[j], (sizeof(*tp_list) *
410		       (MAX_THR_RATES - (j + 1))));
411	}
412	if (j < MAX_THR_RATES)
413		tp_list[j] = index;
414}
415
416/*
417 * Find and set the topmost probability rate per sta and per group
418 */
419static void
420minstrel_ht_set_best_prob_rate(struct minstrel_ht_sta *mi, u16 index)
421{
422	struct minstrel_mcs_group_data *mg;
423	struct minstrel_rate_stats *mrs;
424	int tmp_group, tmp_idx, tmp_tp_avg, tmp_prob;
425	int max_tp_group, cur_tp_avg, cur_group, cur_idx;
426	int max_gpr_group, max_gpr_idx;
427	int max_gpr_tp_avg, max_gpr_prob;
428
429	cur_group = index / MCS_GROUP_RATES;
430	cur_idx = index % MCS_GROUP_RATES;
431	mg = &mi->groups[index / MCS_GROUP_RATES];
432	mrs = &mg->rates[index % MCS_GROUP_RATES];
433
434	tmp_group = mi->max_prob_rate / MCS_GROUP_RATES;
435	tmp_idx = mi->max_prob_rate % MCS_GROUP_RATES;
436	tmp_prob = mi->groups[tmp_group].rates[tmp_idx].prob_avg;
437	tmp_tp_avg = minstrel_ht_get_tp_avg(mi, tmp_group, tmp_idx, tmp_prob);
438
439	/* if max_tp_rate[0] is from MCS_GROUP max_prob_rate get selected from
440	 * MCS_GROUP as well as CCK_GROUP rates do not allow aggregation */
441	max_tp_group = mi->max_tp_rate[0] / MCS_GROUP_RATES;
442	if((index / MCS_GROUP_RATES == MINSTREL_CCK_GROUP) &&
443	    (max_tp_group != MINSTREL_CCK_GROUP))
444		return;
445
446	max_gpr_group = mg->max_group_prob_rate / MCS_GROUP_RATES;
447	max_gpr_idx = mg->max_group_prob_rate % MCS_GROUP_RATES;
448	max_gpr_prob = mi->groups[max_gpr_group].rates[max_gpr_idx].prob_avg;
449
450	if (mrs->prob_avg > MINSTREL_FRAC(75, 100)) {
451		cur_tp_avg = minstrel_ht_get_tp_avg(mi, cur_group, cur_idx,
452						    mrs->prob_avg);
453		if (cur_tp_avg > tmp_tp_avg)
454			mi->max_prob_rate = index;
455
456		max_gpr_tp_avg = minstrel_ht_get_tp_avg(mi, max_gpr_group,
457							max_gpr_idx,
458							max_gpr_prob);
459		if (cur_tp_avg > max_gpr_tp_avg)
460			mg->max_group_prob_rate = index;
461	} else {
462		if (mrs->prob_avg > tmp_prob)
463			mi->max_prob_rate = index;
464		if (mrs->prob_avg > max_gpr_prob)
465			mg->max_group_prob_rate = index;
466	}
467}
468
469
470/*
471 * Assign new rate set per sta and use CCK rates only if the fastest
472 * rate (max_tp_rate[0]) is from CCK group. This prohibits such sorted
473 * rate sets where MCS and CCK rates are mixed, because CCK rates can
474 * not use aggregation.
475 */
476static void
477minstrel_ht_assign_best_tp_rates(struct minstrel_ht_sta *mi,
478				 u16 tmp_mcs_tp_rate[MAX_THR_RATES],
479				 u16 tmp_cck_tp_rate[MAX_THR_RATES])
480{
481	unsigned int tmp_group, tmp_idx, tmp_cck_tp, tmp_mcs_tp, tmp_prob;
482	int i;
483
484	tmp_group = tmp_cck_tp_rate[0] / MCS_GROUP_RATES;
485	tmp_idx = tmp_cck_tp_rate[0] % MCS_GROUP_RATES;
486	tmp_prob = mi->groups[tmp_group].rates[tmp_idx].prob_avg;
487	tmp_cck_tp = minstrel_ht_get_tp_avg(mi, tmp_group, tmp_idx, tmp_prob);
488
489	tmp_group = tmp_mcs_tp_rate[0] / MCS_GROUP_RATES;
490	tmp_idx = tmp_mcs_tp_rate[0] % MCS_GROUP_RATES;
491	tmp_prob = mi->groups[tmp_group].rates[tmp_idx].prob_avg;
492	tmp_mcs_tp = minstrel_ht_get_tp_avg(mi, tmp_group, tmp_idx, tmp_prob);
493
494	if (tmp_cck_tp > tmp_mcs_tp) {
495		for(i = 0; i < MAX_THR_RATES; i++) {
496			minstrel_ht_sort_best_tp_rates(mi, tmp_cck_tp_rate[i],
497						       tmp_mcs_tp_rate);
498		}
499	}
500
501}
502
503/*
504 * Try to increase robustness of max_prob rate by decrease number of
505 * streams if possible.
506 */
507static inline void
508minstrel_ht_prob_rate_reduce_streams(struct minstrel_ht_sta *mi)
509{
510	struct minstrel_mcs_group_data *mg;
511	int tmp_max_streams, group, tmp_idx, tmp_prob;
512	int tmp_tp = 0;
513
514	tmp_max_streams = minstrel_mcs_groups[mi->max_tp_rate[0] /
515			  MCS_GROUP_RATES].streams;
516	for (group = 0; group < ARRAY_SIZE(minstrel_mcs_groups); group++) {
517		mg = &mi->groups[group];
518		if (!mi->supported[group] || group == MINSTREL_CCK_GROUP)
519			continue;
520
521		tmp_idx = mg->max_group_prob_rate % MCS_GROUP_RATES;
522		tmp_prob = mi->groups[group].rates[tmp_idx].prob_avg;
523
524		if (tmp_tp < minstrel_ht_get_tp_avg(mi, group, tmp_idx, tmp_prob) &&
525		   (minstrel_mcs_groups[group].streams < tmp_max_streams)) {
526				mi->max_prob_rate = mg->max_group_prob_rate;
527				tmp_tp = minstrel_ht_get_tp_avg(mi, group,
528								tmp_idx,
529								tmp_prob);
530		}
531	}
532}
533
534static inline int
535minstrel_get_duration(int index)
536{
537	const struct mcs_group *group = &minstrel_mcs_groups[index / MCS_GROUP_RATES];
538	unsigned int duration = group->duration[index % MCS_GROUP_RATES];
539	return duration << group->shift;
540}
541
542static bool
543minstrel_ht_probe_group(struct minstrel_ht_sta *mi, const struct mcs_group *tp_group,
544						int tp_idx, const struct mcs_group *group)
545{
546	if (group->bw < tp_group->bw)
547		return false;
548
549	if (group->streams == tp_group->streams)
550		return true;
551
552	if (tp_idx < 4 && group->streams == tp_group->streams - 1)
553		return true;
554
555	return group->streams == tp_group->streams + 1;
556}
557
558static void
559minstrel_ht_find_probe_rates(struct minstrel_ht_sta *mi, u16 *rates, int *n_rates,
560			     bool faster_rate)
561{
562	const struct mcs_group *group, *tp_group;
563	int i, g, max_dur;
564	int tp_idx;
565
566	tp_group = &minstrel_mcs_groups[mi->max_tp_rate[0] / MCS_GROUP_RATES];
567	tp_idx = mi->max_tp_rate[0] % MCS_GROUP_RATES;
568
569	max_dur = minstrel_get_duration(mi->max_tp_rate[0]);
570	if (faster_rate)
571		max_dur -= max_dur / 16;
572
573	for (g = 0; g < MINSTREL_GROUPS_NB; g++) {
574		u16 supported = mi->supported[g];
575
576		if (!supported)
577			continue;
578
579		group = &minstrel_mcs_groups[g];
580		if (!minstrel_ht_probe_group(mi, tp_group, tp_idx, group))
581			continue;
582
583		for (i = 0; supported; supported >>= 1, i++) {
584			int idx;
585
586			if (!(supported & 1))
587				continue;
588
589			if ((group->duration[i] << group->shift) > max_dur)
590				continue;
591
592			idx = g * MCS_GROUP_RATES + i;
593			if (idx == mi->max_tp_rate[0])
594				continue;
595
596			rates[(*n_rates)++] = idx;
597			break;
598		}
599	}
600}
601
602static void
603minstrel_ht_rate_sample_switch(struct minstrel_priv *mp,
604			       struct minstrel_ht_sta *mi)
605{
606	struct minstrel_rate_stats *mrs;
607	u16 rates[MINSTREL_GROUPS_NB];
608	int n_rates = 0;
609	int probe_rate = 0;
610	bool faster_rate;
611	int i;
612	u8 random;
613
614	/*
615	 * Use rate switching instead of probing packets for devices with
616	 * little control over retry fallback behavior
617	 */
618	if (mp->hw->max_rates > 1)
619		return;
620
621	/*
622	 * If the current EWMA prob is >75%, look for a rate that's 6.25%
623	 * faster than the max tp rate.
624	 * If that fails, look again for a rate that is at least as fast
625	 */
626	mrs = minstrel_get_ratestats(mi, mi->max_tp_rate[0]);
627	faster_rate = mrs->prob_avg > MINSTREL_FRAC(75, 100);
628	minstrel_ht_find_probe_rates(mi, rates, &n_rates, faster_rate);
629	if (!n_rates && faster_rate)
630		minstrel_ht_find_probe_rates(mi, rates, &n_rates, false);
631
632	/* If no suitable rate was found, try to pick the next one in the group */
633	if (!n_rates) {
634		int g_idx = mi->max_tp_rate[0] / MCS_GROUP_RATES;
635		u16 supported = mi->supported[g_idx];
636
637		supported >>= mi->max_tp_rate[0] % MCS_GROUP_RATES;
638		for (i = 0; supported; supported >>= 1, i++) {
639			if (!(supported & 1))
640				continue;
641
642			probe_rate = mi->max_tp_rate[0] + i;
643			goto out;
644		}
645
646		return;
647	}
648
649	i = 0;
650	if (n_rates > 1) {
651		random = prandom_u32();
652		i = random % n_rates;
653	}
654	probe_rate = rates[i];
655
656out:
657	mi->sample_rate = probe_rate;
658	mi->sample_mode = MINSTREL_SAMPLE_ACTIVE;
659}
660
661/*
662 * Update rate statistics and select new primary rates
663 *
664 * Rules for rate selection:
665 *  - max_prob_rate must use only one stream, as a tradeoff between delivery
666 *    probability and throughput during strong fluctuations
667 *  - as long as the max prob rate has a probability of more than 75%, pick
668 *    higher throughput rates, even if the probablity is a bit lower
669 */
670static void
671minstrel_ht_update_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
672			 bool sample)
673{
674	struct minstrel_mcs_group_data *mg;
675	struct minstrel_rate_stats *mrs;
676	int group, i, j, cur_prob;
677	u16 tmp_mcs_tp_rate[MAX_THR_RATES], tmp_group_tp_rate[MAX_THR_RATES];
678	u16 tmp_cck_tp_rate[MAX_THR_RATES], index;
679
680	mi->sample_mode = MINSTREL_SAMPLE_IDLE;
681
682	if (sample) {
683		mi->total_packets_cur = mi->total_packets -
684					mi->total_packets_last;
685		mi->total_packets_last = mi->total_packets;
686	}
687	if (!mp->sample_switch)
688		sample = false;
689	if (mi->total_packets_cur < SAMPLE_SWITCH_THR && mp->sample_switch != 1)
690	    sample = false;
691
692	if (mi->ampdu_packets > 0) {
693		if (!ieee80211_hw_check(mp->hw, TX_STATUS_NO_AMPDU_LEN))
694			mi->avg_ampdu_len = minstrel_ewma(mi->avg_ampdu_len,
695				MINSTREL_FRAC(mi->ampdu_len, mi->ampdu_packets),
696					      EWMA_LEVEL);
697		else
698			mi->avg_ampdu_len = 0;
699		mi->ampdu_len = 0;
700		mi->ampdu_packets = 0;
701	}
702
703	mi->sample_slow = 0;
704	mi->sample_count = 0;
705
706	memset(tmp_mcs_tp_rate, 0, sizeof(tmp_mcs_tp_rate));
707	memset(tmp_cck_tp_rate, 0, sizeof(tmp_cck_tp_rate));
708	if (mi->supported[MINSTREL_CCK_GROUP])
709		for (j = 0; j < ARRAY_SIZE(tmp_cck_tp_rate); j++)
710			tmp_cck_tp_rate[j] = MINSTREL_CCK_GROUP * MCS_GROUP_RATES;
711
712	if (mi->supported[MINSTREL_VHT_GROUP_0])
713		index = MINSTREL_VHT_GROUP_0 * MCS_GROUP_RATES;
714	else
715		index = MINSTREL_HT_GROUP_0 * MCS_GROUP_RATES;
716
717	for (j = 0; j < ARRAY_SIZE(tmp_mcs_tp_rate); j++)
718		tmp_mcs_tp_rate[j] = index;
719
720	/* Find best rate sets within all MCS groups*/
721	for (group = 0; group < ARRAY_SIZE(minstrel_mcs_groups); group++) {
722
723		mg = &mi->groups[group];
724		if (!mi->supported[group])
725			continue;
726
727		mi->sample_count++;
728
729		/* (re)Initialize group rate indexes */
730		for(j = 0; j < MAX_THR_RATES; j++)
731			tmp_group_tp_rate[j] = MCS_GROUP_RATES * group;
732
733		for (i = 0; i < MCS_GROUP_RATES; i++) {
734			if (!(mi->supported[group] & BIT(i)))
735				continue;
736
737			index = MCS_GROUP_RATES * group + i;
738
739			mrs = &mg->rates[i];
740			mrs->retry_updated = false;
741			minstrel_calc_rate_stats(mp, mrs);
742			cur_prob = mrs->prob_avg;
743
744			if (minstrel_ht_get_tp_avg(mi, group, i, cur_prob) == 0)
745				continue;
746
747			/* Find max throughput rate set */
748			if (group != MINSTREL_CCK_GROUP) {
749				minstrel_ht_sort_best_tp_rates(mi, index,
750							       tmp_mcs_tp_rate);
751			} else if (group == MINSTREL_CCK_GROUP) {
752				minstrel_ht_sort_best_tp_rates(mi, index,
753							       tmp_cck_tp_rate);
754			}
755
756			/* Find max throughput rate set within a group */
757			minstrel_ht_sort_best_tp_rates(mi, index,
758						       tmp_group_tp_rate);
759
760			/* Find max probability rate per group and global */
761			minstrel_ht_set_best_prob_rate(mi, index);
762		}
763
764		memcpy(mg->max_group_tp_rate, tmp_group_tp_rate,
765		       sizeof(mg->max_group_tp_rate));
766	}
767
768	/* Assign new rate set per sta */
769	minstrel_ht_assign_best_tp_rates(mi, tmp_mcs_tp_rate, tmp_cck_tp_rate);
770	memcpy(mi->max_tp_rate, tmp_mcs_tp_rate, sizeof(mi->max_tp_rate));
771
772	/* Try to increase robustness of max_prob_rate*/
773	minstrel_ht_prob_rate_reduce_streams(mi);
774
775	/* try to sample all available rates during each interval */
776	mi->sample_count *= 8;
777	if (mp->new_avg)
778		mi->sample_count /= 2;
779
780	if (sample)
781		minstrel_ht_rate_sample_switch(mp, mi);
782
783#ifdef CONFIG_MAC80211_DEBUGFS
784	/* use fixed index if set */
785	if (mp->fixed_rate_idx != -1) {
786		for (i = 0; i < 4; i++)
787			mi->max_tp_rate[i] = mp->fixed_rate_idx;
788		mi->max_prob_rate = mp->fixed_rate_idx;
789		mi->sample_mode = MINSTREL_SAMPLE_IDLE;
790	}
791#endif
792
793	/* Reset update timer */
794	mi->last_stats_update = jiffies;
795}
796
797static bool
798minstrel_ht_txstat_valid(struct minstrel_priv *mp, struct ieee80211_tx_rate *rate)
799{
800	if (rate->idx < 0)
801		return false;
802
803	if (!rate->count)
804		return false;
805
806	if (rate->flags & IEEE80211_TX_RC_MCS ||
807	    rate->flags & IEEE80211_TX_RC_VHT_MCS)
808		return true;
809
810	return rate->idx == mp->cck_rates[0] ||
811	       rate->idx == mp->cck_rates[1] ||
812	       rate->idx == mp->cck_rates[2] ||
813	       rate->idx == mp->cck_rates[3];
814}
815
816static void
817minstrel_set_next_sample_idx(struct minstrel_ht_sta *mi)
818{
819	struct minstrel_mcs_group_data *mg;
820
821	for (;;) {
822		mi->sample_group++;
823		mi->sample_group %= ARRAY_SIZE(minstrel_mcs_groups);
824		mg = &mi->groups[mi->sample_group];
825
826		if (!mi->supported[mi->sample_group])
827			continue;
828
829		if (++mg->index >= MCS_GROUP_RATES) {
830			mg->index = 0;
831			if (++mg->column >= ARRAY_SIZE(sample_table))
832				mg->column = 0;
833		}
834		break;
835	}
836}
837
838static void
839minstrel_downgrade_rate(struct minstrel_ht_sta *mi, u16 *idx, bool primary)
840{
841	int group, orig_group;
842
843	orig_group = group = *idx / MCS_GROUP_RATES;
844	while (group > 0) {
845		group--;
846
847		if (!mi->supported[group])
848			continue;
849
850		if (minstrel_mcs_groups[group].streams >
851		    minstrel_mcs_groups[orig_group].streams)
852			continue;
853
854		if (primary)
855			*idx = mi->groups[group].max_group_tp_rate[0];
856		else
857			*idx = mi->groups[group].max_group_tp_rate[1];
858		break;
859	}
860}
861
862static void
863minstrel_aggr_check(struct ieee80211_sta *pubsta, struct sk_buff *skb)
864{
865	struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data;
866	struct sta_info *sta = container_of(pubsta, struct sta_info, sta);
867	u16 tid;
868
869	if (skb_get_queue_mapping(skb) == IEEE80211_AC_VO)
870		return;
871
872	if (unlikely(!ieee80211_is_data_qos(hdr->frame_control)))
873		return;
874
875	if (unlikely(skb->protocol == cpu_to_be16(ETH_P_PAE)))
876		return;
877
878	tid = ieee80211_get_tid(hdr);
879	if (likely(sta->ampdu_mlme.tid_tx[tid]))
880		return;
881
882	ieee80211_start_tx_ba_session(pubsta, tid, 0);
883}
884
885static void
886minstrel_ht_tx_status(void *priv, struct ieee80211_supported_band *sband,
887                      void *priv_sta, struct ieee80211_tx_status *st)
888{
889	struct ieee80211_tx_info *info = st->info;
890	struct minstrel_ht_sta_priv *msp = priv_sta;
891	struct minstrel_ht_sta *mi = &msp->ht;
892	struct ieee80211_tx_rate *ar = info->status.rates;
893	struct minstrel_rate_stats *rate, *rate2, *rate_sample = NULL;
894	struct minstrel_priv *mp = priv;
895	u32 update_interval = mp->update_interval / 2;
896	bool last, update = false;
897	bool sample_status = false;
898	int i;
899
900	if (!msp->is_ht)
901		return mac80211_minstrel.tx_status_ext(priv, sband,
902						       &msp->legacy, st);
903
904
905	/* This packet was aggregated but doesn't carry status info */
906	if ((info->flags & IEEE80211_TX_CTL_AMPDU) &&
907	    !(info->flags & IEEE80211_TX_STAT_AMPDU))
908		return;
909
910	if (!(info->flags & IEEE80211_TX_STAT_AMPDU)) {
911		info->status.ampdu_ack_len =
912			(info->flags & IEEE80211_TX_STAT_ACK ? 1 : 0);
913		info->status.ampdu_len = 1;
914	}
915
916	mi->ampdu_packets++;
917	mi->ampdu_len += info->status.ampdu_len;
918
919	if (!mi->sample_wait && !mi->sample_tries && mi->sample_count > 0) {
920		int avg_ampdu_len = minstrel_ht_avg_ampdu_len(mi);
921
922		mi->sample_wait = 16 + 2 * avg_ampdu_len;
923		mi->sample_tries = 1;
924		mi->sample_count--;
925	}
926
927	if (info->flags & IEEE80211_TX_CTL_RATE_CTRL_PROBE)
928		mi->sample_packets += info->status.ampdu_len;
929
930	if (mi->sample_mode != MINSTREL_SAMPLE_IDLE)
931		rate_sample = minstrel_get_ratestats(mi, mi->sample_rate);
932
933	last = !minstrel_ht_txstat_valid(mp, &ar[0]);
934	for (i = 0; !last; i++) {
935		last = (i == IEEE80211_TX_MAX_RATES - 1) ||
936		       !minstrel_ht_txstat_valid(mp, &ar[i + 1]);
937
938		rate = minstrel_ht_get_stats(mp, mi, &ar[i]);
939		if (rate == rate_sample)
940			sample_status = true;
941
942		if (last)
943			rate->success += info->status.ampdu_ack_len;
944
945		rate->attempts += ar[i].count * info->status.ampdu_len;
946	}
947
948	switch (mi->sample_mode) {
949	case MINSTREL_SAMPLE_IDLE:
950		if (mp->new_avg &&
951		    (mp->hw->max_rates > 1 ||
952		     mi->total_packets_cur < SAMPLE_SWITCH_THR))
953			update_interval /= 2;
954		break;
955
956	case MINSTREL_SAMPLE_ACTIVE:
957		if (!sample_status)
958			break;
959
960		mi->sample_mode = MINSTREL_SAMPLE_PENDING;
961		update = true;
962		break;
963
964	case MINSTREL_SAMPLE_PENDING:
965		if (sample_status)
966			break;
967
968		update = true;
969		minstrel_ht_update_stats(mp, mi, false);
970		break;
971	}
972
973
974	if (mp->hw->max_rates > 1) {
975		/*
976		 * check for sudden death of spatial multiplexing,
977		 * downgrade to a lower number of streams if necessary.
978		 */
979		rate = minstrel_get_ratestats(mi, mi->max_tp_rate[0]);
980		if (rate->attempts > 30 &&
981		    rate->success < rate->attempts / 4) {
982			minstrel_downgrade_rate(mi, &mi->max_tp_rate[0], true);
983			update = true;
984		}
985
986		rate2 = minstrel_get_ratestats(mi, mi->max_tp_rate[1]);
987		if (rate2->attempts > 30 &&
988		    rate2->success < rate2->attempts / 4) {
989			minstrel_downgrade_rate(mi, &mi->max_tp_rate[1], false);
990			update = true;
991		}
992	}
993
994	if (time_after(jiffies, mi->last_stats_update + update_interval)) {
995		update = true;
996		minstrel_ht_update_stats(mp, mi, true);
997	}
998
999	if (update)
1000		minstrel_ht_update_rates(mp, mi);
1001}
1002
/*
 * Compute how many attempts of rate @index fit into mp->segment_size.
 * retry_count is the count without RTS/CTS and retry_count_rtscts the count
 * with mi->overhead_rtscts; both are stored in the rate's stats.
 * Time values here appear to be in microseconds: the rate duration (ns) is
 * divided by 1000.
 *
 * Rates below 10% success probability get a single attempt. retry_updated
 * is not set on that path, so callers that check it (such as
 * minstrel_ht_set_rate()) will run this calculation again.
 */
static void
minstrel_calc_retransmit(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
                         int index)
{
	struct minstrel_rate_stats *mrs;
	unsigned int tx_time, tx_time_rtscts, tx_time_data;
	unsigned int cw = mp->cw_min;
	unsigned int ctime = 0;
	unsigned int t_slot = 9; /* FIXME: slot time is hard-coded */
	unsigned int ampdu_len = minstrel_ht_avg_ampdu_len(mi);
	unsigned int overhead = 0, overhead_rtscts = 0;

	mrs = minstrel_get_ratestats(mi, index);
	if (mrs->prob_avg < MINSTREL_FRAC(1, 10)) {
		mrs->retry_count = 1;
		mrs->retry_count_rtscts = 1;
		return;
	}

	mrs->retry_count = 2;
	mrs->retry_count_rtscts = 2;
	mrs->retry_updated = true;

	/* air time of one average A-MPDU at this rate */
	tx_time_data = minstrel_get_duration(index) * ampdu_len / 1000;

	/*
	 * Contention time for first 2 tries: half the contention window (in
	 * slots) times the slot time. After each try the window doubles (+1),
	 * capped at cw_max.
	 */
	ctime = (t_slot * cw) >> 1;
	cw = min((cw << 1) | 1, mp->cw_max);
	ctime += (t_slot * cw) >> 1;
	cw = min((cw << 1) | 1, mp->cw_max);

	/* CCK frames are not charged the (A-MPDU) protocol overhead */
	if (index / MCS_GROUP_RATES != MINSTREL_CCK_GROUP) {
		overhead = mi->overhead;
		overhead_rtscts = mi->overhead_rtscts;
	}

	/* Total TX time for data and Contention after first 2 tries */
	tx_time = ctime + 2 * (overhead + tx_time_data);
	tx_time_rtscts = ctime + 2 * (overhead_rtscts + tx_time_data);

	/*
	 * See how many more tries we can fit inside segment size. retry_count
	 * is only incremented while the non-RTS time still fits, and it is
	 * capped at mp->max_retry.
	 */
	do {
		/* Contention time for this try */
		ctime = (t_slot * cw) >> 1;
		cw = min((cw << 1) | 1, mp->cw_max);

		/* Total TX time after this try */
		tx_time += ctime + overhead + tx_time_data;
		tx_time_rtscts += ctime + overhead_rtscts + tx_time_data;

		if (tx_time_rtscts < mp->segment_size)
			mrs->retry_count_rtscts++;
	} while ((tx_time < mp->segment_size) &&
	         (++mrs->retry_count < mp->max_retry));
}
1058
1059
1060static void
1061minstrel_ht_set_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
1062                     struct ieee80211_sta_rates *ratetbl, int offset, int index)
1063{
1064	const struct mcs_group *group = &minstrel_mcs_groups[index / MCS_GROUP_RATES];
1065	struct minstrel_rate_stats *mrs;
1066	u8 idx;
1067	u16 flags = group->flags;
1068
1069	mrs = minstrel_get_ratestats(mi, index);
1070	if (!mrs->retry_updated)
1071		minstrel_calc_retransmit(mp, mi, index);
1072
1073	if (mrs->prob_avg < MINSTREL_FRAC(20, 100) || !mrs->retry_count) {
1074		ratetbl->rate[offset].count = 2;
1075		ratetbl->rate[offset].count_rts = 2;
1076		ratetbl->rate[offset].count_cts = 2;
1077	} else {
1078		ratetbl->rate[offset].count = mrs->retry_count;
1079		ratetbl->rate[offset].count_cts = mrs->retry_count;
1080		ratetbl->rate[offset].count_rts = mrs->retry_count_rtscts;
1081	}
1082
1083	if (index / MCS_GROUP_RATES == MINSTREL_CCK_GROUP)
1084		idx = mp->cck_rates[index % ARRAY_SIZE(mp->cck_rates)];
1085	else if (flags & IEEE80211_TX_RC_VHT_MCS)
1086		idx = ((group->streams - 1) << 4) |
1087		      ((index % MCS_GROUP_RATES) & 0xF);
1088	else
1089		idx = index % MCS_GROUP_RATES + (group->streams - 1) * 8;
1090
1091	/* enable RTS/CTS if needed:
1092	 *  - if station is in dynamic SMPS (and streams > 1)
1093	 *  - for fallback rates, to increase chances of getting through
1094	 */
1095	if (offset > 0 ||
1096	    (mi->sta->smps_mode == IEEE80211_SMPS_DYNAMIC &&
1097	     group->streams > 1)) {
1098		ratetbl->rate[offset].count = ratetbl->rate[offset].count_rts;
1099		flags |= IEEE80211_TX_RC_USE_RTS_CTS;
1100	}
1101
1102	ratetbl->rate[offset].idx = idx;
1103	ratetbl->rate[offset].flags = flags;
1104}
1105
1106static inline int
1107minstrel_ht_get_prob_avg(struct minstrel_ht_sta *mi, int rate)
1108{
1109	int group = rate / MCS_GROUP_RATES;
1110	rate %= MCS_GROUP_RATES;
1111	return mi->groups[group].rates[rate].prob_avg;
1112}
1113
1114static int
1115minstrel_ht_get_max_amsdu_len(struct minstrel_ht_sta *mi)
1116{
1117	int group = mi->max_prob_rate / MCS_GROUP_RATES;
1118	const struct mcs_group *g = &minstrel_mcs_groups[group];
1119	int rate = mi->max_prob_rate % MCS_GROUP_RATES;
1120	unsigned int duration;
1121
1122	/* Disable A-MSDU if max_prob_rate is bad */
1123	if (mi->groups[group].rates[rate].prob_avg < MINSTREL_FRAC(50, 100))
1124		return 1;
1125
1126	duration = g->duration[rate];
1127	duration <<= g->shift;
1128
1129	/* If the rate is slower than single-stream MCS1, make A-MSDU limit small */
1130	if (duration > MCS_DURATION(1, 0, 52))
1131		return 500;
1132
1133	/*
1134	 * If the rate is slower than single-stream MCS4, limit A-MSDU to usual
1135	 * data packet size
1136	 */
1137	if (duration > MCS_DURATION(1, 0, 104))
1138		return 1600;
1139
1140	/*
1141	 * If the rate is slower than single-stream MCS7, or if the max throughput
1142	 * rate success probability is less than 75%, limit A-MSDU to twice the usual
1143	 * data packet size
1144	 */
1145	if (duration > MCS_DURATION(1, 0, 260) ||
1146	    (minstrel_ht_get_prob_avg(mi, mi->max_tp_rate[0]) <
1147	     MINSTREL_FRAC(75, 100)))
1148		return 3200;
1149
1150	/*
1151	 * HT A-MPDU limits maximum MPDU size under BA agreement to 4095 bytes.
1152	 * Since aggregation sessions are started/stopped without txq flush, use
1153	 * the limit here to avoid the complexity of having to de-aggregate
1154	 * packets in the queue.
1155	 */
1156	if (!mi->sta->vht_cap.vht_supported)
1157		return IEEE80211_MAX_MPDU_LEN_HT_BA;
1158
1159	/* unlimited */
1160	return 0;
1161}
1162
1163static void
1164minstrel_ht_update_rates(struct minstrel_priv *mp, struct minstrel_ht_sta *mi)
1165{
1166	struct ieee80211_sta_rates *rates;
1167	u16 first_rate = mi->max_tp_rate[0];
1168	int i = 0;
1169
1170	if (mi->sample_mode == MINSTREL_SAMPLE_ACTIVE)
1171		first_rate = mi->sample_rate;
1172
1173	rates = kzalloc(sizeof(*rates), GFP_ATOMIC);
1174	if (!rates)
1175		return;
1176
1177	/* Start with max_tp_rate[0] */
1178	minstrel_ht_set_rate(mp, mi, rates, i++, first_rate);
1179
1180	if (mp->hw->max_rates >= 3) {
1181		/* At least 3 tx rates supported, use max_tp_rate[1] next */
1182		minstrel_ht_set_rate(mp, mi, rates, i++, mi->max_tp_rate[1]);
1183	}
1184
1185	if (mp->hw->max_rates >= 2) {
1186		minstrel_ht_set_rate(mp, mi, rates, i++, mi->max_prob_rate);
1187	}
1188
1189	mi->sta->max_rc_amsdu_len = minstrel_ht_get_max_amsdu_len(mi);
1190	rates->rate[i].idx = -1;
1191	rate_control_set_rates(mp->hw, mi->sta, rates);
1192}
1193
/*
 * minstrel_get_sample_rate - pick the rate to probe with for this packet
 * @mp: global minstrel state
 * @mi: per-station HT state
 *
 * Returns a global rate index (group * MCS_GROUP_RATES + rate in group) to
 * send this packet at as a probe, or -1 if this packet is not to be used
 * for sampling.
 *
 * Side effects on @mi:
 * - may decrement sample_wait;
 * - once past the wait/tries checks, calls minstrel_set_next_sample_idx(),
 *   even if -1 is returned later;
 * - may increment sample_slow;
 * - decrements sample_tries when a rate is returned.
 */
static int
minstrel_get_sample_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi)
{
	struct minstrel_rate_stats *mrs;
	struct minstrel_mcs_group_data *mg;
	unsigned int sample_dur, sample_group, cur_max_tp_streams;
	int tp_rate1, tp_rate2;
	int sample_idx = 0;

	/*
	 * Without multi-rate retry (max_rates == 1), a non-zero sample_switch
	 * disables sampling. This applies unconditionally when it is 1, and
	 * otherwise once mi->total_packets_cur has reached SAMPLE_SWITCH_THR.
	 */
	if (mp->hw->max_rates == 1 && mp->sample_switch &&
	    (mi->total_packets_cur >= SAMPLE_SWITCH_THR ||
	     mp->sample_switch == 1))
		return -1;

	/* no sampling while the sample_wait countdown is running */
	if (mi->sample_wait > 0) {
		mi->sample_wait--;
		return -1;
	}

	/* sampling budget exhausted */
	if (!mi->sample_tries)
		return -1;

	/* take the next candidate from this group's sample table position */
	sample_group = mi->sample_group;
	mg = &mi->groups[sample_group];
	sample_idx = sample_table[mg->column][mg->index];
	minstrel_set_next_sample_idx(mi);

	/* candidate is not in the peer's supported mask for this group */
	if (!(mi->supported[sample_group] & BIT(sample_idx)))
		return -1;

	/* stats are indexed within the group; convert to a global index */
	mrs = &mg->rates[sample_idx];
	sample_idx += sample_group * MCS_GROUP_RATES;

	/* Set tp_rate1, tp_rate2 to the highest / second highest max_tp_rate */
	/* (tp_rate1 is the one with the shorter duration, i.e. the faster) */
	if (minstrel_get_duration(mi->max_tp_rate[0]) >
	    minstrel_get_duration(mi->max_tp_rate[1])) {
		tp_rate1 = mi->max_tp_rate[1];
		tp_rate2 = mi->max_tp_rate[0];
	} else {
		tp_rate1 = mi->max_tp_rate[0];
		tp_rate2 = mi->max_tp_rate[1];
	}

	/*
	 * Sampling might add some overhead (RTS, no aggregation)
	 * to the frame. Hence, don't use sampling for the highest currently
	 * used highest throughput or probability rate.
	 */
	if (sample_idx == mi->max_tp_rate[0] || sample_idx == mi->max_prob_rate)
		return -1;

	/*
	 * Do not sample if the probability is already higher than 95%,
	 * or if the rate is 3 times slower than the current max probability
	 * rate, to avoid wasting airtime.
	 */
	sample_dur = minstrel_get_duration(sample_idx);
	if (mrs->prob_avg > MINSTREL_FRAC(95, 100) ||
	    minstrel_get_duration(mi->max_prob_rate) * 3 < sample_dur)
		return -1;


	/*
	 * For devices with no configurable multi-rate retry, skip sampling
	 * below the per-group max throughput rate, and only use one sampling
	 * attempt per rate
	 */
	if (mp->hw->max_rates == 1 &&
	    (minstrel_get_duration(mg->max_group_tp_rate[0]) < sample_dur ||
	     mrs->attempts))
		return -1;

	/* Skip already sampled slow rates */
	/* ("slow" = no faster than the fastest of the two max_tp_rates) */
	if (sample_dur >= minstrel_get_duration(tp_rate1) && mrs->attempts)
		return -1;

	/*
	 * Make sure that lower rates get sampled only occasionally,
	 * if the link is working perfectly.
	 */

	/*
	 * The gate below applies when the candidate is no faster than
	 * tp_rate2 AND either uses at least as many streams as tp_rate1 or is
	 * no faster than max_prob_rate. Such a candidate is only sampled once
	 * its sample_skipped counter reaches 20, and only while the
	 * pre-increment value of mi->sample_slow is at most 2.
	 */
	cur_max_tp_streams = minstrel_mcs_groups[tp_rate1 /
		MCS_GROUP_RATES].streams;
	if (sample_dur >= minstrel_get_duration(tp_rate2) &&
	    (cur_max_tp_streams - 1 <
	     minstrel_mcs_groups[sample_group].streams ||
	     sample_dur >= minstrel_get_duration(mi->max_prob_rate))) {
		if (mrs->sample_skipped < 20)
			return -1;

		if (mi->sample_slow++ > 2)
			return -1;
	}
	/* a probe is being issued: consume one sampling try */
	mi->sample_tries--;

	return sample_idx;
}
1291
1292static void
1293minstrel_ht_get_rate(void *priv, struct ieee80211_sta *sta, void *priv_sta,
1294                     struct ieee80211_tx_rate_control *txrc)
1295{
1296	const struct mcs_group *sample_group;
1297	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb);
1298	struct ieee80211_tx_rate *rate = &info->status.rates[0];
1299	struct minstrel_ht_sta_priv *msp = priv_sta;
1300	struct minstrel_ht_sta *mi = &msp->ht;
1301	struct minstrel_priv *mp = priv;
1302	int sample_idx;
1303
1304	if (!msp->is_ht)
1305		return mac80211_minstrel.get_rate(priv, sta, &msp->legacy, txrc);
1306
1307	if (!(info->flags & IEEE80211_TX_CTL_AMPDU) &&
1308	    mi->max_prob_rate / MCS_GROUP_RATES != MINSTREL_CCK_GROUP)
1309		minstrel_aggr_check(sta, txrc->skb);
1310
1311	info->flags |= mi->tx_flags;
1312
1313#ifdef CONFIG_MAC80211_DEBUGFS
1314	if (mp->fixed_rate_idx != -1)
1315		return;
1316#endif
1317
1318	/* Don't use EAPOL frames for sampling on non-mrr hw */
1319	if (mp->hw->max_rates == 1 &&
1320	    (info->control.flags & IEEE80211_TX_CTRL_PORT_CTRL_PROTO))
1321		sample_idx = -1;
1322	else
1323		sample_idx = minstrel_get_sample_rate(mp, mi);
1324
1325	mi->total_packets++;
1326
1327	/* wraparound */
1328	if (mi->total_packets == ~0) {
1329		mi->total_packets = 0;
1330		mi->sample_packets = 0;
1331	}
1332
1333	if (sample_idx < 0)
1334		return;
1335
1336	sample_group = &minstrel_mcs_groups[sample_idx / MCS_GROUP_RATES];
1337	sample_idx %= MCS_GROUP_RATES;
1338
1339	if (sample_group == &minstrel_mcs_groups[MINSTREL_CCK_GROUP] &&
1340	    (sample_idx >= 4) != txrc->short_preamble)
1341		return;
1342
1343	info->flags |= IEEE80211_TX_CTL_RATE_CTRL_PROBE;
1344	rate->count = 1;
1345
1346	if (sample_group == &minstrel_mcs_groups[MINSTREL_CCK_GROUP]) {
1347		int idx = sample_idx % ARRAY_SIZE(mp->cck_rates);
1348		rate->idx = mp->cck_rates[idx];
1349	} else if (sample_group->flags & IEEE80211_TX_RC_VHT_MCS) {
1350		ieee80211_rate_set_vht(rate, sample_idx % MCS_GROUP_RATES,
1351				       sample_group->streams);
1352	} else {
1353		rate->idx = sample_idx + (sample_group->streams - 1) * 8;
1354	}
1355
1356	rate->flags = sample_group->flags;
1357}
1358
1359static void
1360minstrel_ht_update_cck(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
1361		       struct ieee80211_supported_band *sband,
1362		       struct ieee80211_sta *sta)
1363{
1364	int i;
1365
1366	if (sband->band != NL80211_BAND_2GHZ)
1367		return;
1368
1369	if (!ieee80211_hw_check(mp->hw, SUPPORTS_HT_CCK_RATES))
1370		return;
1371
1372	mi->cck_supported = 0;
1373	mi->cck_supported_short = 0;
1374	for (i = 0; i < 4; i++) {
1375		if (!rate_supported(sta, sband->band, mp->cck_rates[i]))
1376			continue;
1377
1378		mi->cck_supported |= BIT(i);
1379		if (sband->bitrates[i].flags & IEEE80211_RATE_SHORT_PREAMBLE)
1380			mi->cck_supported_short |= BIT(i);
1381	}
1382
1383	mi->supported[MINSTREL_CCK_GROUP] = mi->cck_supported;
1384}
1385
/*
 * minstrel_ht_update_caps - (re)build per-station state from capabilities
 * @priv: global minstrel state
 * @sband: band in use for the station
 * @chandef: channel definition (only forwarded to legacy minstrel here)
 * @sta: station whose HT/VHT capabilities are evaluated
 * @priv_sta: per-station minstrel_ht_sta_priv
 *
 * Stations without HT support, or for which no HT/VHT group ends up
 * supported, are handed to legacy minstrel (msp->is_ht = false). Otherwise
 * the HT state is zeroed and rebuilt from scratch: per-frame overheads,
 * sampling parameters, STBC/LDPC tx flags and the supported-rate mask of
 * every group. Finally an initial rate table is created.
 */
static void
minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband,
			struct cfg80211_chan_def *chandef,
                        struct ieee80211_sta *sta, void *priv_sta)
{
	struct minstrel_priv *mp = priv;
	struct minstrel_ht_sta_priv *msp = priv_sta;
	struct minstrel_ht_sta *mi = &msp->ht;
	struct ieee80211_mcs_info *mcs = &sta->ht_cap.mcs;
	u16 ht_cap = sta->ht_cap.cap;
	struct ieee80211_sta_vht_cap *vht_cap = &sta->vht_cap;
	int use_vht;
	int n_supported = 0;
	int ack_dur;
	int stbc;
	int i;
	bool ldpc;

	/* fall back to the old minstrel for legacy stations */
	if (!sta->ht_cap.ht_supported)
		goto use_legacy;

	BUILD_BUG_ON(ARRAY_SIZE(minstrel_mcs_groups) != MINSTREL_GROUPS_NB);

	/* an all-ones VHT tx_mcs_map is treated as "no VHT rates" */
	if (vht_cap->vht_supported)
		use_vht = vht_cap->vht_mcs.tx_mcs_map != cpu_to_le16(~0);
	else
		use_vht = 0;

	/* start from a clean HT state on every (re)initialisation */
	msp->is_ht = true;
	memset(mi, 0, sizeof(*mi));

	mi->sta = sta;
	mi->last_stats_update = jiffies;

	/*
	 * Per-frame overhead: a 0-byte frame plus a 10-byte frame (ack_dur),
	 * both via ieee80211_frame_duration() at rate argument 60. RTS/CTS
	 * protection adds two more ack_dur.
	 */
	ack_dur = ieee80211_frame_duration(sband->band, 10, 60, 1, 1, 0);
	mi->overhead = ieee80211_frame_duration(sband->band, 0, 60, 1, 1, 0);
	mi->overhead += ack_dur;
	mi->overhead_rtscts = mi->overhead + 2 * ack_dur;

	mi->avg_ampdu_len = MINSTREL_FRAC(1, 1);

	/* When using MRR, sample more on the first attempt, without delay */
	if (mp->has_mrr) {
		mi->sample_count = 16;
		mi->sample_wait = 0;
	} else {
		mi->sample_count = 8;
		mi->sample_wait = 8;
	}
	mi->sample_tries = 4;

	/* STBC/LDPC come from the VHT capabilities when VHT is used */
	if (!use_vht) {
		stbc = (ht_cap & IEEE80211_HT_CAP_RX_STBC) >>
			IEEE80211_HT_CAP_RX_STBC_SHIFT;

		ldpc = ht_cap & IEEE80211_HT_CAP_LDPC_CODING;
	} else {
		stbc = (vht_cap->cap & IEEE80211_VHT_CAP_RXSTBC_MASK) >>
			IEEE80211_VHT_CAP_RXSTBC_SHIFT;

		ldpc = vht_cap->cap & IEEE80211_VHT_CAP_RXLDPC;
	}

	mi->tx_flags |= stbc << IEEE80211_TX_CTL_STBC_SHIFT;
	if (ldpc)
		mi->tx_flags |= IEEE80211_TX_CTL_LDPC;

	/* compute the supported-rate mask of every rate group */
	for (i = 0; i < ARRAY_SIZE(mi->groups); i++) {
		u32 gflags = minstrel_mcs_groups[i].flags;
		int bw, nss;

		mi->supported[i] = 0;
		if (i == MINSTREL_CCK_GROUP) {
			minstrel_ht_update_cck(mp, mi, sband, sta);
			continue;
		}

		/*
		 * SGI groups need the HT SGI bit for their width. Groups that
		 * are not 40 MHz (including VHT 80 MHz) check SGI_20 here.
		 * NOTE(review): 80 MHz SGI groups therefore also need HT SGI_20
		 * in addition to VHT SHORT_GI_80 below. Confirm that this is
		 * intended.
		 */
		if (gflags & IEEE80211_TX_RC_SHORT_GI) {
			if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH) {
				if (!(ht_cap & IEEE80211_HT_CAP_SGI_40))
					continue;
			} else {
				if (!(ht_cap & IEEE80211_HT_CAP_SGI_20))
					continue;
			}
		}

		if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH &&
		    sta->bandwidth < IEEE80211_STA_RX_BW_40)
			continue;

		nss = minstrel_mcs_groups[i].streams;

		/* Mark MCS > 7 as unsupported if STA is in static SMPS mode */
		if (sta->smps_mode == IEEE80211_SMPS_STATIC && nss > 1)
			continue;

		/* HT rate */
		/* (skipped for VHT peers when minstrel_vht_only is set) */
		if (gflags & IEEE80211_TX_RC_MCS) {
			if (use_vht && minstrel_vht_only)
				continue;

			mi->supported[i] = mcs->rx_mask[nss - 1];
			if (mi->supported[i])
				n_supported++;
			continue;
		}

		/* VHT rate */
		/* (160 MHz groups are not expected here and trigger a WARN) */
		if (!vht_cap->vht_supported ||
		    WARN_ON(!(gflags & IEEE80211_TX_RC_VHT_MCS)) ||
		    WARN_ON(gflags & IEEE80211_TX_RC_160_MHZ_WIDTH))
			continue;

		if (gflags & IEEE80211_TX_RC_80_MHZ_WIDTH) {
			if (sta->bandwidth < IEEE80211_STA_RX_BW_80 ||
			    ((gflags & IEEE80211_TX_RC_SHORT_GI) &&
			     !(vht_cap->cap & IEEE80211_VHT_CAP_SHORT_GI_80))) {
				continue;
			}
		}

		if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH)
			bw = BW_40;
		else if (gflags & IEEE80211_TX_RC_80_MHZ_WIDTH)
			bw = BW_80;
		else
			bw = BW_20;

		mi->supported[i] = minstrel_get_valid_vht_rates(bw, nss,
				vht_cap->vht_mcs.tx_mcs_map);

		if (mi->supported[i])
			n_supported++;
	}

	/* the CCK group alone is not enough to use minstrel_ht */
	if (!n_supported)
		goto use_legacy;

	/* CCK bits 4..7: short-preamble variants of CCK rates 0..3 */
	mi->supported[MINSTREL_CCK_GROUP] |= mi->cck_supported_short << 4;

	/* create an initial rate table with the lowest supported rates */
	minstrel_ht_update_stats(mp, mi, true);
	minstrel_ht_update_rates(mp, mi);

	return;

use_legacy:
	/* legacy minstrel runs on the buffers from minstrel_ht_alloc_sta() */
	msp->is_ht = false;
	memset(&msp->legacy, 0, sizeof(msp->legacy));
	msp->legacy.r = msp->ratelist;
	msp->legacy.sample_table = msp->sample_table;
	return mac80211_minstrel.rate_init(priv, sband, chandef, sta,
					   &msp->legacy);
}
1542
/*
 * minstrel_ht_rate_init - rate_control_ops .rate_init hook
 *
 * Initial per-station setup is the same work as a capability update.
 */
static void
minstrel_ht_rate_init(void *priv, struct ieee80211_supported_band *sband,
		      struct cfg80211_chan_def *chandef,
		      struct ieee80211_sta *sta, void *priv_sta)
{
	minstrel_ht_update_caps(priv, sband, chandef, sta, priv_sta);
}
1550
1551static void
1552minstrel_ht_rate_update(void *priv, struct ieee80211_supported_band *sband,
1553			struct cfg80211_chan_def *chandef,
1554                        struct ieee80211_sta *sta, void *priv_sta,
1555                        u32 changed)
1556{
1557	minstrel_ht_update_caps(priv, sband, chandef, sta, priv_sta);
1558}
1559
1560static void *
1561minstrel_ht_alloc_sta(void *priv, struct ieee80211_sta *sta, gfp_t gfp)
1562{
1563	struct ieee80211_supported_band *sband;
1564	struct minstrel_ht_sta_priv *msp;
1565	struct minstrel_priv *mp = priv;
1566	struct ieee80211_hw *hw = mp->hw;
1567	int max_rates = 0;
1568	int i;
1569
1570	for (i = 0; i < NUM_NL80211_BANDS; i++) {
1571		sband = hw->wiphy->bands[i];
1572		if (sband && sband->n_bitrates > max_rates)
1573			max_rates = sband->n_bitrates;
1574	}
1575
1576	msp = kzalloc(sizeof(*msp), gfp);
1577	if (!msp)
1578		return NULL;
1579
1580	msp->ratelist = kcalloc(max_rates, sizeof(struct minstrel_rate), gfp);
1581	if (!msp->ratelist)
1582		goto error;
1583
1584	msp->sample_table = kmalloc_array(max_rates, SAMPLE_COLUMNS, gfp);
1585	if (!msp->sample_table)
1586		goto error1;
1587
1588	return msp;
1589
1590error1:
1591	kfree(msp->ratelist);
1592error:
1593	kfree(msp);
1594	return NULL;
1595}
1596
1597static void
1598minstrel_ht_free_sta(void *priv, struct ieee80211_sta *sta, void *priv_sta)
1599{
1600	struct minstrel_ht_sta_priv *msp = priv_sta;
1601
1602	kfree(msp->sample_table);
1603	kfree(msp->ratelist);
1604	kfree(msp);
1605}
1606
1607static void
1608minstrel_ht_init_cck_rates(struct minstrel_priv *mp)
1609{
1610	static const int bitrates[4] = { 10, 20, 55, 110 };
1611	struct ieee80211_supported_band *sband;
1612	u32 rate_flags = ieee80211_chandef_rate_flags(&mp->hw->conf.chandef);
1613	int i, j;
1614
1615	sband = mp->hw->wiphy->bands[NL80211_BAND_2GHZ];
1616	if (!sband)
1617		return;
1618
1619	for (i = 0; i < sband->n_bitrates; i++) {
1620		struct ieee80211_rate *rate = &sband->bitrates[i];
1621
1622		if (rate->flags & IEEE80211_RATE_ERP_G)
1623			continue;
1624
1625		if ((rate_flags & sband->bitrates[i].flags) != rate_flags)
1626			continue;
1627
1628		for (j = 0; j < ARRAY_SIZE(bitrates); j++) {
1629			if (rate->bitrate != bitrates[j])
1630				continue;
1631
1632			mp->cck_rates[j] = i;
1633			break;
1634		}
1635	}
1636}
1637
1638static void *
1639minstrel_ht_alloc(struct ieee80211_hw *hw)
1640{
1641	struct minstrel_priv *mp;
1642
1643	mp = kzalloc(sizeof(struct minstrel_priv), GFP_ATOMIC);
1644	if (!mp)
1645		return NULL;
1646
1647	mp->sample_switch = -1;
1648
1649	/* contention window settings
1650	 * Just an approximation. Using the per-queue values would complicate
1651	 * the calculations and is probably unnecessary */
1652	mp->cw_min = 15;
1653	mp->cw_max = 1023;
1654
1655	/* number of packets (in %) to use for sampling other rates
1656	 * sample less often for non-mrr packets, because the overhead
1657	 * is much higher than with mrr */
1658	mp->lookaround_rate = 5;
1659	mp->lookaround_rate_mrr = 10;
1660
1661	/* maximum time that the hw is allowed to stay in one MRR segment */
1662	mp->segment_size = 6000;
1663
1664	if (hw->max_rate_tries > 0)
1665		mp->max_retry = hw->max_rate_tries;
1666	else
1667		/* safe default, does not necessarily have to match hw properties */
1668		mp->max_retry = 7;
1669
1670	if (hw->max_rates >= 4)
1671		mp->has_mrr = true;
1672
1673	mp->hw = hw;
1674	mp->update_interval = HZ / 10;
1675	mp->new_avg = true;
1676
1677	minstrel_ht_init_cck_rates(mp);
1678
1679	return mp;
1680}
1681
#ifdef CONFIG_MAC80211_DEBUGFS
/*
 * minstrel_ht_add_debugfs - create the global minstrel_ht debugfs knobs
 *
 * fixed_rate_idx starts at (u32)-1, which minstrel_ht_get_rate() treats as
 * "no fixed rate".
 */
static void minstrel_ht_add_debugfs(struct ieee80211_hw *hw, void *priv,
				    struct dentry *debugfsdir)
{
	struct minstrel_priv *mp = priv;

	mp->fixed_rate_idx = (u32) -1;

	/*
	 * 0666 == S_IRUGO | S_IWUGO and 0644 == S_IRUGO | S_IWUSR.
	 * NOTE(review): fixed_rate_idx is world-writable; confirm intended.
	 */
	debugfs_create_u32("fixed_rate_idx", 0666, debugfsdir,
			   &mp->fixed_rate_idx);
	debugfs_create_u32("sample_switch", 0644, debugfsdir,
			   &mp->sample_switch);
	debugfs_create_bool("new_avg", 0644, debugfsdir, &mp->new_avg);
}
#endif
1697
/* Release the global state allocated by minstrel_ht_alloc(). */
static void
minstrel_ht_free(void *priv)
{
	struct minstrel_priv *mp = priv;

	kfree(mp);
}
1703
1704static u32 minstrel_ht_get_expected_throughput(void *priv_sta)
1705{
1706	struct minstrel_ht_sta_priv *msp = priv_sta;
1707	struct minstrel_ht_sta *mi = &msp->ht;
1708	int i, j, prob, tp_avg;
1709
1710	if (!msp->is_ht)
1711		return mac80211_minstrel.get_expected_throughput(priv_sta);
1712
1713	i = mi->max_tp_rate[0] / MCS_GROUP_RATES;
1714	j = mi->max_tp_rate[0] % MCS_GROUP_RATES;
1715	prob = mi->groups[i].rates[j].prob_avg;
1716
1717	/* convert tp_avg from pkt per second in kbps */
1718	tp_avg = minstrel_ht_get_tp_avg(mi, i, j, prob) * 10;
1719	tp_avg = tp_avg * AVG_PKT_SIZE * 8 / 1024;
1720
1721	return tp_avg;
1722}
1723
1724static const struct rate_control_ops mac80211_minstrel_ht = {
1725	.name = "minstrel_ht",
1726	.tx_status_ext = minstrel_ht_tx_status,
1727	.get_rate = minstrel_ht_get_rate,
1728	.rate_init = minstrel_ht_rate_init,
1729	.rate_update = minstrel_ht_rate_update,
1730	.alloc_sta = minstrel_ht_alloc_sta,
1731	.free_sta = minstrel_ht_free_sta,
1732	.alloc = minstrel_ht_alloc,
1733	.free = minstrel_ht_free,
1734#ifdef CONFIG_MAC80211_DEBUGFS
1735	.add_debugfs = minstrel_ht_add_debugfs,
1736	.add_sta_debugfs = minstrel_ht_add_sta_debugfs,
1737#endif
1738	.get_expected_throughput = minstrel_ht_get_expected_throughput,
1739};
1740
1741
1742static void __init init_sample_table(void)
1743{
1744	int col, i, new_idx;
1745	u8 rnd[MCS_GROUP_RATES];
1746
1747	memset(sample_table, 0xff, sizeof(sample_table));
1748	for (col = 0; col < SAMPLE_COLUMNS; col++) {
1749		prandom_bytes(rnd, sizeof(rnd));
1750		for (i = 0; i < MCS_GROUP_RATES; i++) {
1751			new_idx = (i + rnd[i]) % MCS_GROUP_RATES;
1752			while (sample_table[col][new_idx] != 0xff)
1753				new_idx = (new_idx + 1) % MCS_GROUP_RATES;
1754
1755			sample_table[col][new_idx] = i;
1756		}
1757	}
1758}
1759
1760int __init
1761rc80211_minstrel_init(void)
1762{
1763	init_sample_table();
1764	return ieee80211_rate_control_register(&mac80211_minstrel_ht);
1765}
1766
1767void
1768rc80211_minstrel_exit(void)
1769{
1770	ieee80211_rate_control_unregister(&mac80211_minstrel_ht);
1771}
1772