@@ -48,7 +48,7 @@ final class FftMultiplier {
4848 /**
4949 * for FFTs of length up to 2^19
5050 */
51- private static final int ROOTS_CACHE2_SIZE = 20 ;
51+ private static final int ROOTS2_CACHE_SIZE = 20 ;
5252 /**
5353 * The threshold value for using 3-way Toom-Cook multiplication.
5454 */
@@ -58,14 +58,19 @@ final class FftMultiplier {
5858 * elements representing all (2^(k+2))-th roots between 0 and pi/2.
5959 * Used for FFT multiplication.
6060 */
61- private volatile static ComplexVector [] ROOTS2_CACHE = new ComplexVector [ROOTS_CACHE2_SIZE ];
61+ private volatile static ComplexVector [] ROOTS2_CACHE = new ComplexVector [ROOTS2_CACHE_SIZE ];
6262 /**
6363 * Sets of complex roots of unity. The set at index k contains 3*2^k
6464 * elements representing all (3*2^(k+2))-th roots between 0 and pi/2.
6565 * Used for FFT multiplication.
6666 */
6767 private volatile static ComplexVector [] ROOTS3_CACHE = new ComplexVector [ROOTS3_CACHE_SIZE ];
6868
69+ private static final ComplexVector ONE ;
70+ static {
71+ ONE = new ComplexVector (1 );
72+ ONE .set (0 , 1.0 , 0.0 );
73+ }
6974 /**
7075 * Returns the maximum number of bits that one double precision number can fit without
7176 * causing the multiplication to be incorrect.
@@ -118,10 +123,7 @@ static int bitsPerFftPoint(int bitLen) {
118123 */
119124 private static ComplexVector calculateRootsOfUnity (int n ) {
120125 if (n == 1 ) {
121- ComplexVector v = new ComplexVector (1 );
122- v .real (0 , 1 );
123- v .imag (0 , 0 );
124- return v ;
126+ return ONE ;
125127 }
126128 ComplexVector roots = new ComplexVector (n );
127129 roots .set (0 , 1.0 , 0.0 );
@@ -139,6 +141,36 @@ private static ComplexVector calculateRootsOfUnity(int n) {
139141 return roots ;
140142 }
141143
144+ private static ComplexVector calculateRootsOfUnity (int n , ComplexVector prev ) {
145+ if (n == 1 ) {
146+ return ONE ;
147+ }
148+ ComplexVector roots = new ComplexVector (n );
149+ roots .set (0 , 1.0 , 0.0 );
150+ double cos = COS_0_25 ;
151+ double sin = SIN_0_25 ;
152+ roots .set (n / 2 , cos , sin );
153+
154+ double angleTerm = 0.5 * Math .PI / n ;
155+ int ratio = n / prev .length ;
156+ for (int i = 1 , j = 1 ; j < n / 2 ; i ++, j += ratio ) {
157+ for (int k = 0 ; k < ratio - 1 ; k ++) {
158+ int outIdx = j + k ;
159+ double angle = angleTerm * outIdx ;
160+ cos = Math .cos (angle );
161+ sin = Math .sin (angle );
162+ roots .set (outIdx , cos , sin );
163+ roots .set (n - outIdx , sin , cos );
164+ }
165+ cos = prev .real (i );
166+ sin = prev .imag (i );
167+ int outIdx = j + ratio - 1 ;
168+ roots .set (outIdx , cos , sin );
169+ roots .set (n - outIdx , sin , cos );
170+ }
171+ return roots ;
172+ }
173+
142174 /**
143175 * Performs an FFT of length 2^n on the vector {@code a}.
144176 * This is a decimation-in-frequency implementation.
@@ -348,21 +380,33 @@ static BigInteger fromFftVector(ComplexVector fftVec, int signum, int bitsPerFft
348380 *
349381 * @param logN for a transform of length 2^logN
350382 */
351- private static ComplexVector [] getRootsOfUnity2 (int logN ) {
383+ static ComplexVector [] getRootsOfUnity2 (int logN ) {
352384 ComplexVector [] roots = new ComplexVector [logN + 1 ];
353- for (int i = logN ; i >= 0 ; i - = 2 ) {
354- if (i < ROOTS_CACHE2_SIZE ) {
385+ for (int i = logN % 2 ; i <= logN ; i + = 2 ) {
386+ if (i < ROOTS2_CACHE_SIZE ) {
355387 if (ROOTS2_CACHE [i ] == null ) {
356- ROOTS2_CACHE [i ] = calculateRootsOfUnity ( 1 << i );
388+ ROOTS2_CACHE [i ] = getRootOfUnity ( 1 , i , ROOTS2_CACHE );
357389 }
358390 roots [i ] = ROOTS2_CACHE [i ];
359391 } else {
360- roots [i ] = calculateRootsOfUnity ( 1 << i );
392+ roots [i ] = getRootOfUnity ( 1 , i , ROOTS2_CACHE );
361393 }
362394 }
363395 return roots ;
364396 }
365397
398+ private static ComplexVector getRootOfUnity (int b , int e , ComplexVector [] roots ) {
399+ int nearest = floorEntry (e , roots );
400+ return nearest >= 2
401+ ? calculateRootsOfUnity (b << e , roots [nearest ])
402+ : calculateRootsOfUnity (b << e );
403+ }
404+
405+ private static int floorEntry (int i , ComplexVector [] roots ) {
406+ while (i >= 2 && roots [i ] == null ) { i --; }
407+ return i ;
408+ }
409+
366410 /**
367411 * Returns sets of complex roots of unity. For k=logN, logN-2, logN-4, ...,
368412 * the return value contains all k-th roots between 0 and pi/2.
@@ -372,11 +416,11 @@ private static ComplexVector[] getRootsOfUnity2(int logN) {
372416 private static ComplexVector getRootsOfUnity3 (int logN ) {
373417 if (logN < ROOTS3_CACHE_SIZE ) {
374418 if (ROOTS3_CACHE [logN ] == null ) {
375- ROOTS3_CACHE [logN ] = calculateRootsOfUnity ( 3 << logN );
419+ ROOTS3_CACHE [logN ] = getRootOfUnity ( 3 , logN , ROOTS3_CACHE );
376420 }
377421 return ROOTS3_CACHE [logN ];
378422 } else {
379- return calculateRootsOfUnity ( 3 << logN );
423+ return getRootOfUnity ( 3 , logN , ROOTS3_CACHE );
380424 }
381425 }
382426
0 commit comments