1/*
2 * Copyright (c) 2008-2020 Stefan Krah. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright
12 *    notice, this list of conditions and the following disclaimer in the
13 *    documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28
29#ifndef LIBMPDEC_UMODARITH_H_
30#define LIBMPDEC_UMODARITH_H_
31
32
33#include "mpdecimal.h"
34
35#include "constants.h"
36#include "typearith.h"
37
38
39/* Bignum: Low level routines for unsigned modular arithmetic. These are
40   used in the fast convolution functions for very large coefficients. */
41
42
43/**************************************************************************/
44/*                        ANSI modular arithmetic                         */
45/**************************************************************************/
46
47
48/*
49 * Restrictions: a < m and b < m
50 * ACL2 proof: umodarith.lisp: addmod-correct
51 */
52static inline mpd_uint_t
53addmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
54{
55    mpd_uint_t s;
56
57    s = a + b;
58    s = (s < a) ? s - m : s;
59    s = (s >= m) ? s - m : s;
60
61    return s;
62}
63
64/*
65 * Restrictions: a < m and b < m
66 * ACL2 proof: umodarith.lisp: submod-2-correct
67 */
68static inline mpd_uint_t
69submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
70{
71    mpd_uint_t d;
72
73    d = a - b;
74    d = (a < b) ? d + m : d;
75
76    return d;
77}
78
79/*
80 * Restrictions: a < 2m and b < 2m
81 * ACL2 proof: umodarith.lisp: section ext-submod
82 */
83static inline mpd_uint_t
84ext_submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
85{
86    mpd_uint_t d;
87
88    a = (a >= m) ? a - m : a;
89    b = (b >= m) ? b - m : b;
90
91    d = a - b;
92    d = (a < b) ? d + m : d;
93
94    return d;
95}
96
97/*
98 * Reduce double word modulo m.
99 * Restrictions: m != 0
100 * ACL2 proof: umodarith.lisp: section dw-reduce
101 */
102static inline mpd_uint_t
103dw_reduce(mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
104{
105    mpd_uint_t r1, r2, w;
106
107    _mpd_div_word(&w, &r1, hi, m);
108    _mpd_div_words(&w, &r2, r1, lo, m);
109
110    return r2;
111}
112
113/*
114 * Subtract double word from a.
115 * Restrictions: a < m
116 * ACL2 proof: umodarith.lisp: section dw-submod
117 */
118static inline mpd_uint_t
119dw_submod(mpd_uint_t a, mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
120{
121    mpd_uint_t d, r;
122
123    r = dw_reduce(hi, lo, m);
124    d = a - r;
125    d = (a < r) ? d + m : d;
126
127    return d;
128}
129
130#ifdef CONFIG_64
131
132/**************************************************************************/
133/*                        64-bit modular arithmetic                       */
134/**************************************************************************/
135
136/*
137 * A proof of the algorithm is in literature/mulmod-64.txt. An ACL2
138 * proof is in umodarith.lisp: section "Fast modular reduction".
139 *
140 * Algorithm: calculate (a * b) % p:
141 *
142 *   a) hi, lo <- a * b       # Calculate a * b.
143 *
144 *   b) hi, lo <-  R(hi, lo)  # Reduce modulo p.
145 *
146 *   c) Repeat step b) until 0 <= hi * 2**64 + lo < 2*p.
147 *
148 *   d) If the result is less than p, return lo. Otherwise return lo - p.
149 */
150
151static inline mpd_uint_t
152x64_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
153{
154    mpd_uint_t hi, lo, x, y;
155
156
157    _mpd_mul_words(&hi, &lo, a, b);
158
159    if (m & (1ULL<<32)) { /* P1 */
160
161        /* first reduction */
162        x = y = hi;
163        hi >>= 32;
164
165        x = lo - x;
166        if (x > lo) hi--;
167
168        y <<= 32;
169        lo = y + x;
170        if (lo < y) hi++;
171
172        /* second reduction */
173        x = y = hi;
174        hi >>= 32;
175
176        x = lo - x;
177        if (x > lo) hi--;
178
179        y <<= 32;
180        lo = y + x;
181        if (lo < y) hi++;
182
183        return (hi || lo >= m ? lo - m : lo);
184    }
185    else if (m & (1ULL<<34)) { /* P2 */
186
187        /* first reduction */
188        x = y = hi;
189        hi >>= 30;
190
191        x = lo - x;
192        if (x > lo) hi--;
193
194        y <<= 34;
195        lo = y + x;
196        if (lo < y) hi++;
197
198        /* second reduction */
199        x = y = hi;
200        hi >>= 30;
201
202        x = lo - x;
203        if (x > lo) hi--;
204
205        y <<= 34;
206        lo = y + x;
207        if (lo < y) hi++;
208
209        /* third reduction */
210        x = y = hi;
211        hi >>= 30;
212
213        x = lo - x;
214        if (x > lo) hi--;
215
216        y <<= 34;
217        lo = y + x;
218        if (lo < y) hi++;
219
220        return (hi || lo >= m ? lo - m : lo);
221    }
222    else { /* P3 */
223
224        /* first reduction */
225        x = y = hi;
226        hi >>= 24;
227
228        x = lo - x;
229        if (x > lo) hi--;
230
231        y <<= 40;
232        lo = y + x;
233        if (lo < y) hi++;
234
235        /* second reduction */
236        x = y = hi;
237        hi >>= 24;
238
239        x = lo - x;
240        if (x > lo) hi--;
241
242        y <<= 40;
243        lo = y + x;
244        if (lo < y) hi++;
245
246        /* third reduction */
247        x = y = hi;
248        hi >>= 24;
249
250        x = lo - x;
251        if (x > lo) hi--;
252
253        y <<= 40;
254        lo = y + x;
255        if (lo < y) hi++;
256
257        return (hi || lo >= m ? lo - m : lo);
258    }
259}
260
261static inline void
262x64_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
263{
264    *a = x64_mulmod(*a, w, m);
265    *b = x64_mulmod(*b, w, m);
266}
267
268static inline void
269x64_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
270            mpd_uint_t m)
271{
272    *a0 = x64_mulmod(*a0, b0, m);
273    *a1 = x64_mulmod(*a1, b1, m);
274}
275
276static inline mpd_uint_t
277x64_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
278{
279    mpd_uint_t r = 1;
280
281    while (exp > 0) {
282        if (exp & 1)
283            r = x64_mulmod(r, base, umod);
284        base = x64_mulmod(base, base, umod);
285        exp >>= 1;
286    }
287
288    return r;
289}
290
291/* END CONFIG_64 */
292#else /* CONFIG_32 */
293
294
295/**************************************************************************/
296/*                        32-bit modular arithmetic                       */
297/**************************************************************************/
298
299#if defined(ANSI)
300#if !defined(LEGACY_COMPILER)
301/* HAVE_UINT64_T */
302static inline mpd_uint_t
303std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
304{
305    return ((mpd_uuint_t) a * b) % m;
306}
307
308static inline void
309std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
310{
311    *a = ((mpd_uuint_t) *a * w) % m;
312    *b = ((mpd_uuint_t) *b * w) % m;
313}
314
315static inline void
316std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
317            mpd_uint_t m)
318{
319    *a0 = ((mpd_uuint_t) *a0 * b0) % m;
320    *a1 = ((mpd_uuint_t) *a1 * b1) % m;
321}
322/* END HAVE_UINT64_T */
323#else
324/* LEGACY_COMPILER */
325static inline mpd_uint_t
326std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
327{
328    mpd_uint_t hi, lo, q, r;
329    _mpd_mul_words(&hi, &lo, a, b);
330    _mpd_div_words(&q, &r, hi, lo, m);
331    return r;
332}
333
334static inline void
335std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
336{
337    *a = std_mulmod(*a, w, m);
338    *b = std_mulmod(*b, w, m);
339}
340
341static inline void
342std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
343            mpd_uint_t m)
344{
345    *a0 = std_mulmod(*a0, b0, m);
346    *a1 = std_mulmod(*a1, b1, m);
347}
348/* END LEGACY_COMPILER */
349#endif
350
351static inline mpd_uint_t
352std_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
353{
354    mpd_uint_t r = 1;
355
356    while (exp > 0) {
357        if (exp & 1)
358            r = std_mulmod(r, base, umod);
359        base = std_mulmod(base, base, umod);
360        exp >>= 1;
361    }
362
363    return r;
364}
365#endif /* ANSI CONFIG_32 */
366
367
368/**************************************************************************/
369/*                    Pentium Pro modular arithmetic                      */
370/**************************************************************************/
371
372/*
373 * A proof of the algorithm is in literature/mulmod-ppro.txt. The FPU
374 * control word must be set to 64-bit precision and truncation mode
375 * prior to using these functions.
376 *
377 * Algorithm: calculate (a * b) % p:
378 *
379 *   p    := prime < 2**31
380 *   pinv := (long double)1.0 / p (precalculated)
381 *
382 *   a) n = a * b              # Calculate exact product.
383 *   b) qest = n * pinv        # Calculate estimate for q = n / p.
384 *   c) q = (qest+2**63)-2**63 # Truncate qest to the exact quotient.
385 *   d) r = n - q * p          # Calculate remainder.
386 *
387 * Remarks:
388 *
389 *   - p = dmod and pinv = dinvmod.
390 *   - dinvmod points to an array of three uint32_t, which is interpreted
391 *     as an 80 bit long double by fldt.
392 *   - Intel compilers prior to version 11 do not seem to handle the
393 *     __GNUC__ inline assembly correctly.
394 *   - random tests are provided in tests/extended/ppro_mulmod.c
395 */
396
397#if defined(PPRO)
398#if defined(ASM)
399
/*
 * Return (a * b) % dmod
 *
 * x87 implementation of steps a) to d) of the algorithm described above.
 *
 * Asm operands:
 *   %0  retval     (memory, output)
 *   %1  a          (memory)
 *   %2  b          (memory)
 *   %3  dmod       (register holding the pointer; read as a double)
 *   %4  dinvmod    (register holding the pointer; read as an 80-bit value)
 *   %5  MPD_TWO63  (memory; read as a single-precision float)
 *
 * a and b are loaded with fildl, i.e. as signed 32-bit integers. The
 * result is written with fistpl, which also pops the last stack entry,
 * so the FPU stack is balanced on exit.
 *
 * Note on the stack comments below: in GNU as AT&T syntax the
 * register-destination forms of fsubp/fsubrp are reversed relative to
 * the Intel mnemonics; "fsubrp %st, %st(1)" computes st(1) - st(0).
 */
static inline mpd_uint_t
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm__ (
            "fildl  %2\n\t"                 /* st: b */
            "fildl  %1\n\t"                 /* st: a, b */
            "fmulp  %%st, %%st(1)\n\t"      /* st: n = a*b */
            "fldt   (%4)\n\t"               /* st: pinv, n */
            "fmul   %%st(1), %%st\n\t"      /* st: qest = n*pinv, n */
            "flds   %5\n\t"                 /* st: MPD_TWO63, qest, n */
            "fadd   %%st, %%st(1)\n\t"      /* st: MPD_TWO63, qest+MPD_TWO63, n */
            "fsubrp %%st, %%st(1)\n\t"      /* st: q, n */
            "fldl   (%3)\n\t"               /* st: p, q, n */
            "fmulp  %%st, %%st(1)\n\t"      /* st: q*p, n */
            "fsubrp %%st, %%st(1)\n\t"      /* st: r = n - q*p */
            "fistpl %0\n\t"                 /* retval = r, stack empty */
            : "=m" (retval)
            : "m" (a), "m" (b), "r" (dmod), "r" (dinvmod), "m" (MPD_TWO63)
            : "st", "memory"
    );

    return retval;
}
426
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * w) % dmod
 *      *a1 = (*a1 * w) % dmod
 *
 * Asm operands:
 *   %0  a0         (register holding the pointer; read and written)
 *   %1  a1         (register holding the pointer; read and written)
 *   %2  w          (memory)
 *   %3  dmod       (register holding the pointer; read as a double)
 *   %4  dinvmod    (register holding the pointer; read as an 80-bit value)
 *   %5  MPD_TWO63  (memory; read as a single-precision float)
 *
 * The statement has no output operands; the results are stored through
 * the a0 and a1 pointers, which is covered by the "memory" clobber.
 * Both fistpl instructions pop, so the FPU stack is balanced on exit.
 *
 * In GNU as AT&T syntax the register-destination forms of fsubp/fsubrp
 * are reversed relative to the Intel mnemonics: "fsubrp %st, %st(i)"
 * computes st(i) - st(0) and "fsubp %st, %st(1)" computes st(0) - st(1).
 */
static inline void
ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
              double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            /* products: st = n0 = *a0 * w, n1 = *a1 * w */
            "fildl  %2\n\t"
            "fildl  (%1)\n\t"
            "fmul   %%st(1), %%st\n\t"
            "fxch   %%st(1)\n\t"
            "fildl  (%0)\n\t"
            "fmulp  %%st, %%st(1) \n\t"
            /* constants: st = MPD_TWO63, pinv, n0, n1 */
            "fldt   (%4)\n\t"
            "flds   %5\n\t"
            /* first remainder: q0 from n0 * pinv, then n0 - q0 * p */
            "fld    %%st(2)\n\t"
            "fmul   %%st(2)\n\t"
            "fadd   %%st(1)\n\t"
            "fsub   %%st(1)\n\t"
            "fmull  (%3)\n\t"
            "fsubrp %%st, %%st(3)\n\t"
            "fxch   %%st(2)\n\t"
            "fistpl (%0)\n\t"
            /* second remainder: st = pinv, MPD_TWO63, n1 at this point */
            "fmul   %%st(2)\n\t"
            "fadd   %%st(1)\n\t"
            "fsubp  %%st, %%st(1)\n\t"
            "fmull  (%3)\n\t"
            "fsubrp %%st, %%st(1)\n\t"
            "fistpl (%1)\n\t"
            : : "r" (a0), "r" (a1), "m" (w),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
465
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * b0) % dmod
 *      *a1 = (*a1 * b1) % dmod
 *
 * Asm operands:
 *   %0  a0         (register holding the pointer; read and written)
 *   %1  b0         (memory)
 *   %2  a1         (register holding the pointer; read and written)
 *   %3  b1         (memory)
 *   %4  dmod       (register holding the pointer; read as a double)
 *   %5  dinvmod    (register holding the pointer; read as an 80-bit value)
 *   %6  MPD_TWO63  (memory; read as a single-precision float)
 *
 * The two computations are interleaved with fxch. The statement has no
 * output operands; the results are stored through the a1 and a0 pointers
 * (in that order), which is covered by the "memory" clobber. Both fistpl
 * instructions pop, so the FPU stack is balanced on exit.
 *
 * In GNU as AT&T syntax the register-destination forms of fsubp/fsubrp
 * are reversed relative to the Intel mnemonics: "fsubrp %st, %st(i)"
 * computes st(i) - st(0) and "fsubp %st, %st(1)" computes st(0) - st(1).
 */
static inline void
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            /* products: st = n0 = *a0 * b0, n1 = *a1 * b1 */
            "fildl  %3\n\t"
            "fildl  (%2)\n\t"
            "fmulp  %%st, %%st(1)\n\t"
            "fildl  %1\n\t"
            "fildl  (%0)\n\t"
            "fmulp  %%st, %%st(1)\n\t"
            /* quotient estimates: st = n0 * pinv, n1 * pinv, n0, n1 */
            "fldt   (%5)\n\t"
            "fld    %%st(2)\n\t"
            "fmul   %%st(1), %%st\n\t"
            "fxch   %%st(1)\n\t"
            "fmul   %%st(2), %%st\n\t"
            /* add and subtract MPD_TWO63 to obtain the quotients q0, q1 */
            "flds   %6\n\t"
            "fldl   (%4)\n\t"
            "fxch   %%st(3)\n\t"
            "fadd   %%st(1), %%st\n\t"
            "fxch   %%st(2)\n\t"
            "fadd   %%st(1), %%st\n\t"
            "fxch   %%st(2)\n\t"
            "fsub   %%st(1), %%st\n\t"
            "fxch   %%st(2)\n\t"
            "fsubp  %%st, %%st(1)\n\t"
            /* remainders: n1 - q1 * p and n0 - q0 * p */
            "fxch   %%st(1)\n\t"
            "fmul   %%st(2), %%st\n\t"
            "fxch   %%st(1)\n\t"
            "fmulp  %%st, %%st(2)\n\t"
            "fsubrp %%st, %%st(3)\n\t"
            "fsubrp %%st, %%st(1)\n\t"
            /* store *a1 first, then *a0 */
            "fxch   %%st(1)\n\t"
            "fistpl (%2)\n\t"
            "fistpl (%0)\n\t"
            : : "r" (a0), "m" (b0), "r" (a1), "m" (b1),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
512/* END PPRO GCC ASM */
513#elif defined(MASM)
514
/*
 * Return (a * b) % dmod
 *
 * MASM-syntax x87 implementation of steps a) to d) of the algorithm
 * described above:
 *   - eax and edx are loaded with the dinvmod and dmod pointers,
 *   - a and b are loaded as integers (fild) and multiplied,
 *   - the product is multiplied by the 80-bit value at [eax] (pinv),
 *   - MPD_TWO63 is added and subtracted again to obtain the quotient,
 *   - the quotient is multiplied by the double at [edx] (p) and the
 *     result is subtracted from the product,
 *   - the remainder is stored to retval with fistp, which pops the last
 *     stack entry.
 */
static inline mpd_uint_t __cdecl
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm {
        mov     eax, dinvmod
        mov     edx, dmod
        fild    b
        fild    a
        fmulp   st(1), st
        fld     TBYTE PTR [eax]
        fmul    st, st(1)
        fld     MPD_TWO63
        fadd    st(1), st
        fsubp   st(1), st
        fld     QWORD PTR [edx]
        fmulp   st(1), st
        fsubp   st(1), st
        fistp   retval
    }

    return retval;
}
540
541/*
542 * Two modular multiplications in parallel:
543 *      *a0 = (*a0 * w) % dmod
544 *      *a1 = (*a1 * w) % dmod
545 */
546static inline mpd_uint_t __cdecl
547ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
548              double *dmod, uint32_t *dinvmod)
549{
550    __asm {
551        mov     ecx, dmod
552        mov     edx, a1
553        mov     ebx, dinvmod
554        mov     eax, a0
555        fild    w
556        fild    DWORD PTR [edx]
557        fmul    st, st(1)
558        fxch    st(1)
559        fild    DWORD PTR [eax]
560        fmulp   st(1), st
561        fld     TBYTE PTR [ebx]
562        fld     MPD_TWO63
563        fld     st(2)
564        fmul    st, st(2)
565        fadd    st, st(1)
566        fsub    st, st(1)
567        fmul    QWORD PTR [ecx]
568        fsubp   st(3), st
569        fxch    st(2)
570        fistp   DWORD PTR [eax]
571        fmul    st, st(2)
572        fadd    st, st(1)
573        fsubrp  st(1), st
574        fmul    QWORD PTR [ecx]
575        fsubp   st(1), st
576        fistp   DWORD PTR [edx]
577    }
578}
579
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * b0) % dmod
 *      *a1 = (*a1 * b1) % dmod
 *
 * MASM-syntax x87 implementation; the two computations are interleaved
 * with fxch. Register use inside the asm block: ecx = dmod, edx = a1,
 * ebx = dinvmod, eax = a0. MPD_TWO63 is loaded as a 32-bit float
 * (DWORD PTR), the modulus as a 64-bit double (QWORD PTR [ecx]) and the
 * inverse as an 80-bit value (TBYTE PTR [ebx]).
 *
 * The results are stored through a1 first, then a0. Both fistp
 * instructions pop, so the FPU stack is balanced on exit.
 */
static inline void __cdecl
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm {
        mov     ecx, dmod
        mov     edx, a1
        mov     ebx, dinvmod
        mov     eax, a0
        fild    b1
        fild    DWORD PTR [edx]
        fmulp   st(1), st
        fild    b0
        fild    DWORD PTR [eax]
        fmulp   st(1), st
        fld     TBYTE PTR [ebx]
        fld     st(2)
        fmul    st, st(1)
        fxch    st(1)
        fmul    st, st(2)
        fld     DWORD PTR MPD_TWO63
        fld     QWORD PTR [ecx]
        fxch    st(3)
        fadd    st, st(1)
        fxch    st(2)
        fadd    st, st(1)
        fxch    st(2)
        fsub    st, st(1)
        fxch    st(2)
        fsubrp  st(1), st
        fxch    st(1)
        fmul    st, st(2)
        fxch    st(1)
        fmulp   st(2), st
        fsubp   st(3), st
        fsubp   st(1), st
        fxch    st(1)
        fistp   DWORD PTR [edx]
        fistp   DWORD PTR [eax]
    }
}
626#endif /* PPRO MASM (_MSC_VER) */
627
628
629/* Return (base ** exp) % dmod */
630static inline mpd_uint_t
631ppro_powmod(mpd_uint_t base, mpd_uint_t exp, double *dmod, uint32_t *dinvmod)
632{
633    mpd_uint_t r = 1;
634
635    while (exp > 0) {
636        if (exp & 1)
637            r = ppro_mulmod(r, base, dmod, dinvmod);
638        base = ppro_mulmod(base, base, dmod, dinvmod);
639        exp >>= 1;
640    }
641
642    return r;
643}
644#endif /* PPRO */
645#endif /* CONFIG_32 */
646
647
648#endif /* LIBMPDEC_UMODARITH_H_ */
649