/* mpn_cube -- Compute the third power (cube) of {ap,an} with Zanoni's
   algorithm.

Written by Alberto Zanoni, based on code written by Marco Bodrato,
Torbjorn Granlund, Robert Harley, Niels Möller, Paul Zimmermann.

The functions mpn_cube_eval_dgr3_pm1, mpn_cube_interpolate_5pts,
mpn_cube, contain code by the author mixed with code from
GNU MP-sources. The different main() programs can have contributions by
different authors.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */

/*
COMPILATION: put this file in the directory where the GMP library has
             previously been built (gmp.h and gmp-impl.h are needed).

- Compilation command to check that the code works properly (libraries
  must come after the source file on the command line):

gcc -O3 -DCHECK cubeNew.c .libs/libgmp.a -lm -o cubeNew
./cubeNew 30
*/

#include "gmp.h"
#include "gmp-impl.h"

#if defined(CHECK)
#include <stdlib.h>
#include <stdio.h>
#include <math.h>

/* Debug helper: print the N limbs at P in hexadecimal, most significant
   limb first, separated by single spaces and followed by a newline.  Each
   limb is zero-padded to the full width of mp_limb_t.  */
void
dumpy (mp_srcptr p, mp_size_t n)
{
  mp_size_t i;
  for (i = n - 1; i >= 0; i--)
    {
      /* "%lx" only matches when mp_limb_t is unsigned long; convert to the
         widest standard unsigned type so the specifier always matches.  */
      printf ("%0*llx", (int) (2 * sizeof (mp_limb_t)),
	      (unsigned long long) p[i]);
      if (i > 0)
	putchar (' ');
    }
  puts ("");
}
#endif

/* Balanced multiplication used for the pointwise products of the
   unbalanced Toom-3 step in mpn_cube: {rp, 2n} = {up, n} * {vp, n}.  */
#define TOOM_MUL_N_REC(rp, up, vp, n) mpn_mul_n ((rp), (up), (vp), (n))

/* Auxiliary function: evaluate, at 1 and at -1, the degree-3 polynomial
   c3*X^3 + c2*X^2 + c1*X + c0 used by Zanoni's algorithm.

   The two least significant coefficients have n+1 limbs (c0 has the form
   27*u, c1 the form 3*v), c2 has n limbs and c3 has s <= n limbs.

   **** PARAMETERS ****
   resultEvaluationIn1         : output, n+1 limbs: c0 + c1 + c2 + c3.
   resultEvaluationInMinus1    : output, n+1 limbs: |c0 - c1 + c2 - c3|.
   firstCoefficient            : c0, n+1 limbs.
   secondThirdFourthCoefficient: c1 (n+1 limbs), immediately followed by
                                 c2 (n limbs) and then c3 (s limbs).
   n                           : the value of n.
   s                           : number of limbs of c3, s <= n (mpn_cube
                                 passes 2t-n, i.e. n or n-2).
   temp                        : scratch area of n+1 limbs.

                                  <-n+1->
   firstCoefficient             : | c0  |
                                  <-s-><-n-><-n+1->
   secondThirdFourthCoefficient : | c3| c2 |  c1  |

   Carries out of the n+1 limb results are not checked: the caller must
   guarantee that c0 + c1 + c2 + c3 fits in n+1 limbs.

   **** RETURNED VALUE ****
    0   if c0 - c1 + c2 - c3 >= 0
   ~0   otherwise
*/
static int
mpn_cube_eval_dgr3_pm1 (mp_ptr    resultEvaluationIn1,
			mp_ptr    resultEvaluationInMinus1,
			mp_srcptr firstCoefficient,
			mp_srcptr secondThirdFourthCoefficient,
			mp_size_t n, mp_size_t s,
			mp_ptr temp )
{
  mp_srcptr c0 = firstCoefficient;
  mp_srcptr c1 = secondThirdFourthCoefficient;
  mp_srcptr c2 = c1 + n + 1;   /* right after the n+1 limbs of c1 */
  mp_srcptr c3 = c2 + n;       /* right after the n limbs of c2   */
  int neg;

  /* Even part: c0 + c2. */
  resultEvaluationIn1[n] = c0[n] + mpn_add_n (resultEvaluationIn1, c0, c2, n);
  /* Odd part: c1 + c3 (c3 may be shorter than n limbs). */
  if (s == n)
    temp[n] = c1[n] + mpn_add_n (temp, c1, c3, n);
  else
    temp[n] = c1[n] + mpn_add (temp, c1, n, c3, s);

  /* The evaluation in -1 (even - odd) is negative iff even < odd. */
  neg = (mpn_cmp (resultEvaluationIn1, temp, n + 1) < 0) ? ~0 : 0;

#if HAVE_NATIVE_mpn_add_n_sub_n
  if (neg)
    mpn_add_n_sub_n (resultEvaluationIn1, resultEvaluationInMinus1, temp, resultEvaluationIn1, n + 1);
  else
    mpn_add_n_sub_n (resultEvaluationIn1, resultEvaluationInMinus1, resultEvaluationIn1, temp, n + 1);
#else
  /* |even - odd| first, while resultEvaluationIn1 still holds "even". */
  if (neg)
    mpn_sub_n (resultEvaluationInMinus1, temp, resultEvaluationIn1, n + 1);
  else
    mpn_sub_n (resultEvaluationInMinus1, resultEvaluationIn1, temp, n + 1);

  mpn_add_n (resultEvaluationIn1, resultEvaluationIn1, temp, n + 1);
#endif
  return neg;
}

/*
   Interpolation and recomposition for the unbalanced Toom-3 step of
   mpn_cube.

   Very similar to mpn_toom_interpolate_5pts, but
   1) interpolation is slightly changed: an mpn_submul_1 (by 81) is used
      instead of a simple subtraction of v0, and an extra division by 9 is
      needed;
   2) interpolation and recomposition phases are not mixed.

   With kk1 = 2k+1, on entry:
     {c, 2k}        v0, the constant coefficient divided by 81;
     {c+2k, 2k+1}   v1 = H(1); its most significant limb is stored at c+4k,
                    where vinf begins;
     {c+4k, twor}   vinf = H(oo), except that its low limb is passed
                    separately as VINF0;
     {v2, 2k+1}     H(2);
     {vm1, 2k+1}    |H(-1)|, with H(-1) < 0 iff signOfMinus1 != 0.
   On exit {c, 4k+twor} holds the recomposed result; {v2, 2k+1} and
   {vm1, 2k+1} are clobbered.  No limb beyond c+4k+twor is accessed.
*/
static void
mpn_cube_interpolate_5pts (mp_ptr c, mp_ptr v2, mp_ptr vm1,
			   mp_size_t k, mp_size_t twor, int signOfMinus1,
			   mp_limb_t vinf0)
{
  mp_limb_t cy, saved;
  mp_size_t twok, kk1;
  mp_ptr    c1, v1, c3, vinf;

  twok = k<<1;
  kk1  = twok + 1;
  c1   = c  + k;
  v1   = c1 + k;
  c3   = v1 + k;
  vinf = c3 + k;

#define v0 (c)
  /* (1) v2 <- v2-vm1 < v2+|vm1|, (16 8 4 2 1) - (1 -1 1 -1 1) =
                                                 (15 9 3  3 0) */
  if (signOfMinus1)
    ASSERT_NOCARRY (mpn_add_n (v2, v2, vm1, kk1));
  else
    ASSERT_NOCARRY (mpn_sub_n (v2, v2, vm1, kk1));

  /* v2 <- v2 / 3                                       (5 3 1 1 0) */
  ASSERT_NOCARRY (mpn_divexact_by3 (v2, v2, kk1));

  /* (2) vm1 <- tm1 := (v1 - vm1)/2    [(1 1 1 1 1) - (1 -1 1 -1 1)]/2 =
     tm1 >= 0                                         (0  1 0  1 0)
     No carry comes out from {v1, kk1} +/- {vm1, kk1}, and the division by
     two is exact.  If signOfMinus1 != 0, vm1 holds -H(-1), so it is
     added instead of subtracted. */
  if (signOfMinus1)
    {
#ifdef HAVE_NATIVE_mpn_rsh1add_n
      mpn_rsh1add_n (vm1, v1, vm1, kk1);
#else
      ASSERT_NOCARRY (mpn_add_n (vm1, v1, vm1, kk1));
      ASSERT_NOCARRY (mpn_rshift (vm1, vm1, kk1, 1));
#endif
    }
  else
    {
#ifdef HAVE_NATIVE_mpn_rsh1sub_n
      mpn_rsh1sub_n (vm1, v1, vm1, kk1);
#else
      ASSERT_NOCARRY (mpn_sub_n (vm1, v1, vm1, kk1));
      ASSERT_NOCARRY (mpn_rshift (vm1, vm1, kk1, 1));
#endif
    }

  /* (3) v1 <- t1 := v1 - 81*v0   (1 1 1 1 1) - (0 0 0 0 1) = (1 1 1 1 0)
     t1 >= 0.  v0 holds the constant coefficient divided by 81, hence the
     multiplier.  The borrow out of the low 2k limbs is taken from v1's top
     limb, which is stored at vinf[0]. */
  vinf[0] -= mpn_submul_1 (v1, v0, twok, 81);

  /* (4) v2 <- t2 := ((v2-vm1)/3-t1)/2 = (v2-vm1-3*t1)/6
     t2 >= 0                  [(5 3 1 1 0) - (1 1 1 1 0)]/2 = (2 1 0 0 0)
  */
#ifdef HAVE_NATIVE_mpn_rsh1sub_n
  mpn_rsh1sub_n (v2, v2, v1, kk1);
#else
  ASSERT_NOCARRY (mpn_sub_n (v2, v2, v1, kk1));
  ASSERT_NOCARRY (mpn_rshift (v2, v2, kk1, 1));
#endif

  /* (5) v1 <- t1-tm1           (1 1 1 1 0) - (0 1 0 1 0) = (1 0 1 0 0)
     result is v1 >= 0
  */
  ASSERT_NOCARRY (mpn_sub_n (v1, v1, vm1, kk1));
  /* (6) v2 <- v2 - 2*vinf,     (2 1 0 0 0) - 2*(1 0 0 0 0) = (0 1 0 0 0)
     result is v2 >= 0 */
  saved = vinf[0]; /* Remember v1's highest limb (will be overwritten). */
  vinf[0] = vinf0; /* Set the right value for vinf0                    */
#ifdef HAVE_NATIVE_mpn_sublsh1_n
  cy = mpn_sublsh1_n (v2, v2, vinf, twor);
#else
  cy = mpn_submul_1 (v2, vinf, twor, 2);
#endif
  MPN_DECR_U (v2 + twor, kk1 - twor, cy);
  /* Current matrix is
     [1 0 0 0 0   ; vinf
      0 1 0 0 0   ; v2           We have to compute
      1 0 1 0 0   ; v1           v1 -= vinf; vm1 -= v2,
      0 1 0 1 0   ; vm1          divide vm1 by 9 and then
      0 0 0 0 1/81] v0           recompose everything.

     (7) vm1 <- vm1-v2          (0 1 0 1 0) - (0 1 0 0 0) = (0 0 0 1 0) */
  mpn_sub_n (vm1, vm1, v2, kk1);
  /* (8) vm1 <- vm1/9           ADDED STEP                              */
  mpn_divexact_by3 (vm1, vm1, kk1);    /* vm1 <- vm1 / 3 */
  mpn_divexact_by3 (vm1, vm1, kk1);    /* vm1 <- vm1 / 3 */
  /* (9) v1 <- v1 - vinf,       (1 0 1 0 0) - (1 0 0 0 0) = (0 0 1 0 0)
     result is >= 0 */
  cy = mpn_sub_n (v1, v1, vinf, twor);    /* vinf is at most twor long. */
  vinf0 = vinf[0];                 /* Save again the right vinf0 value. */
  vinf[0] = saved;
  MPN_DECR_U (v1 + twor, kk1 - twor, cy);  /* v1's highest limbs.       */

  /* Add vm1 in {c+k,...}; the carry may run up to the end of the
     4k+twor limb result, i.e. over twor+k-1 limbs from c+3k+1. */
  cy = mpn_add_n (c1, c1, vm1, kk1);
  MPN_INCR_U (c3 + 1, twor + k - 1, cy);

  /* Add v2 in {c+3k,...}.  Only twor+k limbs of the result lie at or above
     c+3k, which is fewer than kk1 for the smallest operands.  Every term
     summed so far is non-negative and the sum is bounded by the final
     result, so the limbs of v2 beyond that point are zero and no carry
     can leave the result: never touch memory past {c, 4k+twor}. */
  if (twor + k > kk1)
    {
      cy = mpn_add_n (c3, c3, v2, kk1);
      MPN_INCR_U (c3 + kk1, twor + k - kk1, cy);   /* twor-k-1 limbs */
    }
  else
    {
      ASSERT (twor + k == kk1
	      || mpn_zero_p (v2 + twor + k, kk1 - (twor + k)));
      ASSERT_NOCARRY (mpn_add_n (c3, c3, v2, twor + k));
    }

  /* Final remaining part. */
  MPN_INCR_U (vinf, twor, vinf0);        /* Add vinf0, propagate carry. */

#undef v0
}

/* Number of scratch limbs mpn_cube needs for an operand of AN limbs.
   With n = ceil(an/2) limbs in the low half a0, the temporaries laid out
   in the scratch area by mpn_cube occupy 6n+5 limbs.  */
mp_size_t
mpn_cube_itch (mp_size_t an)
{
  mp_size_t lowLimbs = (an + 1) >> 1;   /* n = ceil(an/2) */

  return 6 * lowLimbs + 5;
}

/* Compute the cube (third power) of a with Zanoni's algorithm.

      <-t-><--n-->            t = n  or  t = n-1
a  :  | a1 |  a0 |

      <----3t----><------3n------>
a^3:  |   a1^3   |     a0^3      |     (a1|a0)^3
             |   3a1a0^2   |<-n->
          |  3a1^2a0  |<-n->

In the following description, t = n, for simplicity.

"Classical" algorithm: a^3 = (a^2)*a, with quadratic complexity C(2n):

1)  a^2    : a square of a O(2n)-limbs number: S(2n)
2) (a^2)*a : a product O(4n) x O(2n)         : P(4n, 2n)

Counting only quadratic operations: using Karatsuba for 1) and
unbalanced Toom-3 - toom42 - for 2) one has, respectively,
S(2n) = 3S(n) and P(4n,2n) = 5P(n), so that

                         C(2n) = 3S(n) + 5P(n)
---------------------------------------------------------------------
Zanoni's algorithm (in the range where Karatsuba and unbalanced Toom-3
are worthwhile):

1) Compute A  = a0^2 = | A1 | A0 |                       : S(n)
2) Compute A' = a1^2 = | A3 | A2 |                       : S(n)
3) Consider polynomials f(X) = A3*X^3 + A2*X^2 + 3*A1*X + 27*A0
                        g(X) = a1*X + 3*a0
   and compute H(X) = f(X)*g(X) = h4*X^4 + h3*X^3 + h2*X^2 + H1*X + H0
   with unbalanced Toom-3 in (0,-1,1,2,oo)               : 5P(n)
4) Set h1 = H1/9,  h0 = H0/81
5) Consider h(X) = h4*X^4 + h3*X^3 + h2*X^2 + h1*X + h0 and recompose

   | h4 | h2 | h0 |      C(2n) = 2S(n) + 5P(n)
      | h3 | h1 |

  The result a^3 is written to {pp, 3*an}.  scratch must provide
  mpn_cube_itch (an) limbs.  pp must not overlap {ap, an}: ap is read
  again after pp has been partially written.  Operands with an < 5 limbs
  are handled by mpn_pow_1.

  Returns the number of meaningful limbs (3*an or 3*an-1 or 3*an-2 when
  the most significant limb of a is nonzero).
*/

mp_size_t
mpn_cube (mp_ptr pp, mp_srcptr ap, mp_size_t an, mp_ptr scratch)
{
  mp_size_t    n, t, rn;
  int          vm1_neg;
  /* Full limbs: vinf0 and saved hold copies of result limbs. */
  mp_limb_t    cy, vinf0, saved;

  /* an == 4 (n = t = 2) must take this path too: there the n+1 limb
     temporary handed to mpn_cube_eval_dgr3_pm1, {pp+2n+2, n+1}, ends at
     pp+3n+2 = pp+4n and would overwrite vinf[0] after vinf is computed. */
  if (an < 5)
    return mpn_pow_1 (pp, ap, an, 3, scratch);

  t = an >> 1;  /* Number of limbs of a1 (either n or n-1) */
  n = an - t;   /* Number of limbs of a0.                  */

  ASSERT (n-2 < t && t <= n);

#define a1      (ap + n)             /* Pointer to high section of a. */
#define A0      (scratch)            /* n    */
#define A1      (scratch + n)        /* n+1  */
#define A2      (scratch + 2*n + 1)  /* n    */
#define A3      (scratch + 3*n + 1)  /* 2t-n */
#define as1     (scratch + 4*n + 3)  /* n+1  */
#define as2     (scratch + 5*n + 4)  /* n+1  */
#define asm1    (pp)                 /* n+1  */
#define tmp     (pp + n + 1)         /* n+1  */
#define vinf    (pp + 4 * n)         /* 3t-n */

  /* Computation of a0^2 and a1^2.                    */
  mpn_sqr (A0, ap, n); /* 2n */
  mpn_sqr (A2, a1, t); /* 2t */
  /* Coefficients adjustment.                         */
  tmp[n] = mpn_mul_1 (tmp, A0, n, 27);  /* 27*A0 : n+1 */
  A1 [n] = mpn_mul_1 (A1 , A1, n,  3);  /*  3*A1 : n+1 */
  /*    | 27A0 |              | A3 | A2 | 3A1 | A0 |  */

  /* vinf = a1*A3, 3*t-n limbs */
  if (t == n)
    mpn_mul_n (vinf, a1, A3, n);
  else
    mpn_mul (vinf, a1, t, A3, 2*t-n);

  /* Unbalanced Toom-3 : Evaluation phase.            */
  /* Evaluation of first factor in 2. Ruffini-Horner. */
#if HAVE_NATIVE_mpn_addlsh1_n
  cy  = mpn_addlsh1_n (as2, A2, A3, 2*t-n);
  if (t != n)
    cy = mpn_add_1 (as2 + (2*t-n), A2 + (2*t-n), n - (2*t-n), cy);
  cy = 2 * cy + A1[n] + mpn_addlsh1_n (as2, A1, as2, n);
  cy = 2 * cy + tmp[n] + mpn_addlsh1_n (as2, tmp, as2, n);
#else
  cy  = mpn_lshift (as2, A3, 2*t-n, 1);           /* 2*A3  */
  cy += mpn_add_n (as2, A2, as2, 2*t-n);
  if (t != n)
    cy = mpn_add_1 (as2 + (2*t-n), A2 + (2*t-n), n - (2*t-n), cy); /* 2*A3 + A2 */
  cy = 2 * cy + mpn_lshift (as2, as2, n, 1);      /* 4*A3 + 2*A2                */
  cy += A1[n] + mpn_add_n (as2, A1, as2, n);      /* 4*A3 + 2*A2 + A1           */
  cy = 2 * cy + mpn_lshift (as2, as2, n, 1);      /* 8*A3 + 4*A2 + 2*A1         */
  cy += tmp[n] + mpn_add_n (as2, tmp, as2, n);    /* 8*A3 + 4*A2 + 6*A1 + 27*A0 */
#endif
  as2[n] = cy;

  /* Evaluation of first factor in 1 and -1; the ~0/0 mask is reduced to
     1/0. */
  vm1_neg = mpn_cube_eval_dgr3_pm1 (as1, asm1, tmp, A1, n, 2*t-n, tmp+n+1) & 1;
  /* Now 27A0 is needed no more: recycle its space. */
  /*    | 3a0 |          | A3 | A2 | 3A1 | A0 |     */
  tmp[n] = mpn_mul_1 (tmp, ap, n, 3); /* 3*a0: n+1   */
  /* Evaluation of second factor in -1: |3*a0 - a1|.  */
  if (mpn_zero_p (tmp + t, n + 1 - t) && mpn_cmp (tmp, a1, t) < 0)
    {
      mpn_sub_n (A1, a1, tmp, t);
      MPN_ZERO (A1 + t, n + 1 - t);
      vm1_neg ^= 1;
    }
  else
    {
      mpn_sub (A1, tmp, n + 1, a1, t);
    }

#define v0	(pp)	               /* 2n        */
#define v1	(pp + 2 * n)           /* 2n+1      */
#define v2	(A0)                   /* 2n+1      */

  /* v1, 2n+1 limbs.  The 2n+2 limb product overwrites vinf[0] and
     vinf[1]: vinf[0] is handed to the interpolation as vinf0, vinf[1] is
     restored afterwards (the product's top limb is 0).  Both are kept as
     full limbs. */
  vinf0 = vinf[0];
  saved = vinf[1];
  mpn_add (A2, tmp, n+1, a1, t);                   /* a1 + 3*a0 */
  TOOM_MUL_N_REC (v1, as1, A2, n+1);
  vinf[1] = saved;
  /* Evaluation of second factor in 2.                          */
  mpn_add (as1, A2, n + 1, a1, t); /* as1 = 2*a1 + 3*a0         */
  /* vm1, 2n+1 limbs. One more limb is needed for multiplication, but
     it will contain 0.      */
  TOOM_MUL_N_REC (A2, asm1, A1, n+1);
  /* v0, 2n limbs            */
  TOOM_MUL_N_REC (v0, ap, A0, n);
  /* v2, 2*n+1 limbs. The extra (zero) limb of the product lands on A2[0],
     the low limb of vm1, which is saved and restored.  */
  saved = A2[0];
  TOOM_MUL_N_REC (v2, as2, as1, n + 1);
  A2[0] = saved;

  /* Unbalanced Toom-3: Interpolation and Recomposition phases. */
  mpn_cube_interpolate_5pts (pp, v2, A2, n, 3*t-n, vm1_neg, vinf0);

  /* Strip at most two high zero limbs: for a normalized operand a^3 has
     at least 3*an-2 limbs. */
  rn = 3 * an;
  rn -= (pp[rn - 1] == 0);
  rn -= (pp[rn - 1] == 0);
  return rn;

#undef a1
#undef as1
#undef asm1
#undef as2
#undef tmp

#undef v0
#undef v1
#undef v2
#undef vinf

#undef A0
#undef A1
#undef A2
#undef A3
}

#if defined(CHECK)

/* Random operand sizes are an = (M-1)*n + s, with 1 <= s <= n. */
#define M 3

/* Random n is drawn from [MINN+1, maxn]. */
#ifndef MINN
#define MINN 1
#endif

/********* CHECKING PROGRAM ********/

/* Compare mpn_cube with mpn_pow_1 (exponent 3) on random operands.

   Usage:  cubeNew MAXN      random sizes; n is drawn from [MINN+1, MAXN]
                             (so MAXN must exceed MINN)
           cubeNew AN X      fixed size an = AN (X is ignored)

   Loops until the first mismatch, which is reported before exit (-1).
   Returns 1 on a missing or invalid size argument.  */
int
main (int argc, char **argv)
{
  mp_size_t n, s, an, clearn;
  mp_ptr ap, refp, pp, tmppp;
  mp_limb_t keep;
  mp_size_t cubeLength;
  int test;
  long maxn;
  char *end;
  int norandom;
  int err = 0;
  TMP_DECL;
  TMP_MARK;

  printf("-----------------------------------------\n");
  norandom = 0;

  if (argc < 2)
    return 1;

  maxn = strtol (argv[1], &end, 0); /* Max length. */
  /* Reject non-numeric and non-positive sizes; the random mode also needs
     maxn > MINN, otherwise random () % (maxn - MINN) divides by zero. */
  if (end == argv[1] || *end != '\0' || maxn < 1
      || (argc != 3 && maxn <= MINN))
    {
      fprintf (stderr, "invalid size argument: %s\n", argv[1]);
      return 1;
    }

  if (argc == 3)
    {
      an = maxn;
      norandom = 1;
    }
  else
    an = M * maxn;

  ap    = TMP_ALLOC_LIMBS (an);
  refp  = TMP_ALLOC_LIMBS (3*an);
  pp    = TMP_ALLOC_LIMBS (3*an + 1+1);
  tmppp = TMP_ALLOC_LIMBS (11*an);

  for (test = 0;; test++)
    {
      if (! norandom)
	{
	  n = random () % (maxn-MINN) + MINN+1;
	  s = random () % n + 1;
	  an = (M - 1) * n + s;
	}
      if (err == 0 && test % 0x10000 == 0)
	{
	  printf ("\r%d\tan = %d\tn = %d\tt = %d ", test, (int) an, (((int) an)+1)/2, ((int) an)>>1);
	  fflush (stdout);
	}
      mpn_random2 (ap, an);
      clearn = random () % (an + 1);
      MPN_ZERO (ap + clearn, an - clearn);

      mpn_random2 (pp, 3*an + 1);       /* Random data in the result space. */
      keep = pp[3*an];                  /* Keep the highest limb.           */

      mpn_cube (pp, ap, an, tmppp);                  /* New cube algorithm. */
      cubeLength = mpn_pow_1(refp, ap, an, 3, tmppp);/* Standard algorithm. */
      if (pp[3*an] != keep || mpn_cmp (refp, pp, cubeLength) != 0)
	{                                            /* In case of error... */
	  printf ("\nERROR in test %d\t", test);
	  printf ("an = %d\tn = %d\tt = %d \n", (int) an, (((int) an)+1)/2, ((int) an)>>1);
	  if (pp[3*an] != keep)
	    {
	      printf ("pp high : "); dumpy (pp + 3*an, 1);
	      printf ("keep    : "); dumpy (&keep, 1);
	    }
	  printf("a    = "); dumpy (ap, an);
	  printf("cube = "); dumpy (pp, 3*an);
	  /* Only cubeLength limbs of refp were written by mpn_pow_1. */
	  printf("GMP  = "); dumpy (refp, cubeLength);
	  if (++err > 0)
	    exit(-1);
	}
    }
  TMP_FREE;
  return 0;
}

#endif // defined(CHECK)

