Further optimizing bigfloat matrix multiplication
October 17, 2024
How many tricks can we incorporate in multiprecision matrix multiplication?
The nfloat and nfloat_complex types in FLINT have recently gained fast matrix multiplication code (improving on the already-decent classical matrix multiplication reported in the last post). As a result, here is the current speedup (time with acf divided by time with nfloat_complex) from using nfloat_complex instead of acf as the floating-point ground type when calling gr_mat_nonsingular_solve to solve a linear system $Ax = b$, where $A$ is a random complex $n \times n$ matrix:
prec (bits) \ n | 2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|
64 | 1.906 | 2.232 | 2.265 | 2.125 | 1.867 | 1.766 | 1.800 | 1.746 | 1.622 | 1.330 |
128 | 1.581 | 2.104 | 2.235 | 2.471 | 2.096 | 1.992 | 2.059 | 1.854 | 1.610 | 1.363 |
192 | 1.823 | 2.181 | 3.439 | 4.088 | 3.573 | 3.453 | 3.575 | 2.948 | 2.158 | 1.870 |
256 | 1.679 | 2.043 | 2.864 | 3.326 | 3.045 | 3.121 | 3.392 | 2.605 | 1.785 | 1.634 |
320 | 1.534 | 1.757 | 2.659 | 2.991 | 2.845 | 2.998 | 3.172 | 2.878 | 2.056 | 1.633 |
384 | 1.620 | 1.748 | 2.134 | 2.496 | 2.533 | 2.813 | 2.927 | 2.078 | 1.921 | 1.682 |
512 | 1.596 | 1.808 | 2.151 | 2.506 | 2.421 | 2.648 | 3.094 | 2.617 | 2.065 | 1.717 |
1024 | 3.098 | 1.818 | 1.859 | 2.019 | 2.291 | 2.730 | 2.741 | 2.439 | 1.976 | 1.678 |
1536 | 2.550 | 1.933 | 1.895 | 1.946 | 2.185 | 2.469 | 2.465 | 2.243 | 1.888 | 1.626 |
2048 | 2.195 | 2.390 | 2.005 | 2.135 | 2.320 | 2.549 | 2.402 | 2.038 | 1.712 | 1.561 |
2560 | 2.201 | 2.164 | 1.902 | 2.072 | 2.265 | 2.656 | 2.527 | 2.163 | 1.862 | 1.667 |
3072 | 1.745 | 1.830 | 1.803 | 1.974 | 2.236 | 2.508 | 2.444 | 2.073 | 1.779 | 1.588 |
4096 | 1.481 | 1.519 | 1.587 | 1.785 | 2.128 | 2.335 | 2.745 | 2.433 | 2.115 | 1.855 |
Over quite significant ranges we now get more than a 2x speedup over acf, and in some cases more than 3x (most notably at 192 and 256 bits for $n$ between 16 and 128, peaking at about 4x for 192 bits and $n = 16$).
For large enough $n$ (greater than a cutoff that varies roughly between 50 and 500 depending on the precision), nfloat and nfloat_complex convert to multimodular integer matrix multiplication, the same trick used for arf and acf. Much of the improvement seen in the table therefore comes from better matrix multiplication for small and medium $n$.
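The conversion to multimodular integer arithmetic can be modelled in a few lines. This is a hypothetical Python sketch, not FLINT's implementation: it scales each matrix by a single shared power of two (`to_fixed`), multiplies the resulting integer matrices modulo enough 31-bit primes to cover the worst-case entry size, and recombines the residues with the Chinese remainder theorem. The helper names, the prime size and the naive per-prime product are assumptions chosen for clarity; a real implementation would use word-size modular arithmetic and a fast matrix multiplication for each prime.

```python
import math
from math import isqrt

# Hypothetical helpers for illustration only; not FLINT's API.


def primes_below(bound):
    """Yield primes below `bound` in decreasing order (trial division; fine for a demo)."""
    for c in range(bound - 1, 1, -1):
        if all(c % d for d in range(2, isqrt(c) + 1)):
            yield c


def matmul_multimodular(A, B):
    """Exact product of integer matrices: multiply modulo several primes, then CRT."""
    rows, inner, cols = len(A), len(B), len(B[0])
    max_a = max((abs(x) for r in A for x in r), default=0)
    max_b = max((abs(x) for r in B for x in r), default=0)
    bound = inner * max_a * max_b  # |C[i][j]| <= bound
    # Collect 31-bit primes until their product exceeds 2 * bound, so that
    # signed entries can be recovered uniquely from the residues.
    primes, modulus = [], 1
    for q in primes_below(2**31):
        if modulus > 2 * bound:
            break
        primes.append(q)
        modulus *= q
    C = [[0] * cols for _ in range(rows)]
    for q in primes:
        # Product modulo q (a real implementation uses word-size arithmetic here).
        Aq = [[x % q for x in r] for r in A]
        Bq = [[x % q for x in r] for r in B]
        Cq = [[sum(a * b for a, b in zip(r, c)) % q for c in zip(*Bq)] for r in Aq]
        # CRT weight: w is 1 modulo q and 0 modulo every other prime.
        Mq = modulus // q
        w = Mq * pow(Mq, -1, q)
        for i in range(rows):
            for j in range(cols):
                C[i][j] += Cq[i][j] * w
    # Reduce into the symmetric range (-modulus/2, modulus/2] to restore signs.
    half = modulus // 2
    return [[x % modulus - modulus if x % modulus > half else x % modulus
             for x in r] for r in C]


def to_fixed(M, prec):
    """Scale M by one shared power of two and round: M[i][j] ~= Z[i][j] * 2**-e."""
    top = max((math.frexp(x)[1] for r in M for x in r if x != 0), default=0)
    e = prec - top  # the largest entry gets about `prec` bits
    return [[round(math.ldexp(x, e)) for x in r] for r in M], e


def float_matmul_via_integers(A, B, prec=53):
    """Approximate float product computed exactly on the scaled integer matrices."""
    ZA, ea = to_fixed(A, prec)
    ZB, eb = to_fixed(B, prec)
    ZC = matmul_multimodular(ZA, ZB)
    return [[math.ldexp(x, -(ea + eb)) for x in r] for r in ZC]
```

With one shared exponent per matrix, entries much smaller than the largest one keep fewer significant bits, which is why badly scaled matrices call for a smarter scaling or splitting strategy.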
Here, I've tried to put together all the time-saving tricks I'm aware of. This amalgamation of tricks deserves a catchy name, so let's call it the Fixed-point-delayed-normalization-truncating-Karatsuba-Mulders-Winograd-Waksman-Strassen-Bodrato-Karatsuba algorithm. Let's break it down:
- We convert internally from floating-point numbers to fixed-point numbers. This speeds up additions and normalizations.
- If the precision is only a few limbs, we use direct truncating multiplication of fixed-point numbers (using inline assembly for extremely few limbs and Albin Ahlbäck's assembly routines otherwise). If the precision is many limbs, we use Mulders's recursive mulhigh together with Karatsuba integer multiplication.
- When evaluating dot products, some carries and sign adjustments are postponed until the end.
- When the precision is several limbs, we use Waksman's improvement of Winograd's trick to reduce basecase matrix multiplication from $n^3$ multiplications and $n^3$ additions to $0.5 n^3 + O(n^2)$ multiplications and $1.5 n^3 + O(n^2)$ additions.
- We use Strassen multiplication with Bodrato's improved scheduling to reduce a sufficiently large $2n \times 2n$ matrix multiplication to seven $n \times n$ matrix multiplications, recursing until $n$ is small enough to use the basecase.
- Finally, to multiply complex matrices, we use the Karatsuba trick to reduce four real matrix multiplications to three.
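Several of the tricks in the list above can be modelled on exact integer matrices. The following Python sketch is a hypothetical illustration, not FLINT's code: it composes a Karatsuba-style complex product (three real products instead of four), a textbook Strassen recursion, and Winograd's original inner-product trick as the basecase. It leaves out the fixed-point truncation, delayed carries, Mulders's mulhigh, Waksman's refinement of Winograd's scheme and Bodrato's scheduling (the classic Strassen formulas are used instead); the function names and the cutoff value are assumptions.

```python
# Hypothetical helpers for illustration only; not FLINT's API.


def add(X, Y):
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def sub(X, Y):
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def winograd_matmul(A, B):
    """Basecase: Winograd's inner-product trick, n^3/2 + O(n^2) multiplications."""
    m, n, p = len(A), len(B), len(B[0])
    h = n // 2
    # O(n^2) correction terms: pairwise products inside rows of A / columns of B.
    row = [sum(A[i][2*j] * A[i][2*j+1] for j in range(h)) for i in range(m)]
    col = [sum(B[2*j][k] * B[2*j+1][k] for j in range(h)) for k in range(p)]
    C = [[0] * p for _ in range(m)]
    for i in range(m):
        for k in range(p):
            # (a0 + b1)(a1 + b0) = a0*b0 + a1*b1 + a0*a1 + b1*b0:
            # one multiplication covers two terms of the dot product.
            s = sum((A[i][2*j] + B[2*j+1][k]) * (A[i][2*j+1] + B[2*j][k])
                    for j in range(h))
            s -= row[i] + col[k]
            if n % 2:  # odd inner dimension: add the leftover term directly
                s += A[i][n-1] * B[n-1][k]
            C[i][k] = s
    return C


def split(M):
    """Split a square matrix of even size into four quadrants."""
    h = len(M) // 2
    return ([r[:h] for r in M[:h]], [r[h:] for r in M[:h]],
            [r[:h] for r in M[h:]], [r[h:] for r in M[h:]])


def strassen(A, B, cutoff=8):
    """Classic Strassen recursion on square matrices: seven half-size products."""
    n = len(A)
    if n <= cutoff or n % 2:
        return winograd_matmul(A, B)
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    M1 = strassen(add(A11, A22), add(B11, B22), cutoff)
    M2 = strassen(add(A21, A22), B11, cutoff)
    M3 = strassen(A11, sub(B12, B22), cutoff)
    M4 = strassen(A22, sub(B21, B11), cutoff)
    M5 = strassen(add(A11, A12), B22, cutoff)
    M6 = strassen(sub(A21, A11), add(B11, B12), cutoff)
    M7 = strassen(sub(A12, A22), add(B21, B22), cutoff)
    C11 = add(sub(add(M1, M4), M5), M7)
    C12 = add(M3, M5)
    C21 = add(M2, M4)
    C22 = add(sub(M1, M2), add(M3, M6))
    return ([r1 + r2 for r1, r2 in zip(C11, C12)] +
            [r1 + r2 for r1, r2 in zip(C21, C22)])


def complex_matmul_3m(Ar, Ai, Br, Bi, matmul=strassen):
    """(Ar + i*Ai)(Br + i*Bi) with three real matrix products instead of four."""
    P1 = matmul(Ar, Br)
    P2 = matmul(Ai, Bi)
    P3 = matmul(add(Ar, Ai), add(Br, Bi))
    return sub(P1, P2), sub(sub(P3, P1), P2)  # (real part, imaginary part)
```

Because everything here is exact integer arithmetic, each layer can be checked directly against the schoolbook product; with truncating fixed-point arithmetic the results would instead agree only up to rounding error. Winograd's trick trades multiplications for additions, so it pays off only when a multiplication costs much more than an addition, as with multi-limb numbers.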
How much further can we optimize matrix multiplication? Well, there are still a few stones unturned:
- It should be possible to squeeze out some more performance by adding more assembly routines for the basic operations on fixed-point numbers.
- Instead of using mpn-compatible fixed-point numbers, it should be faster to switch to a SIMD-friendly representation with many nail bits (especially on AVX512 with IFMA).
- An FFT representation would make sense for extremely high precision; it is conceivable that one could gain something already at 4096 bits (though I'm skeptical).
- The multimodular multiplication in FLINT is far from optimal; improving it should reduce the range of $n$ where it makes sense to do fixed-point math.
- In the benchmark, the matrix entries have nearly uniform magnitudes, so fixed-point arithmetic performs optimally. The scaling/splitting strategy for poorly scaled matrices can certainly be improved.
Most of those tasks are for the distant future, however. A closer goal is to use the new matrix multiplication code to speed up arb, acb, arf and acf matrices as well. Even this will require a fair bit of work to obtain rock-solid error bounds.
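The SIMD-friendly representation with many nail bits, listed above as future work, rests on the same idea as postponing carries in dot products. This hypothetical Python sketch models only the carry-free accumulation, not multiplication: each limb holds 52 payload bits in a 64-bit word (52 matching the operand width of the AVX-512 IFMA multiply-add instructions), the 12 spare bits let up to 4096 limb-wise additions proceed with no carry propagation, and a single normalization pass restores proper limbs. The limb size and all names are assumptions for illustration.

```python
# Hypothetical nail-bit representation for illustration only; not FLINT code.
LIMB_BITS = 52               # payload bits per limb (matches IFMA's 52-bit operands)
WORD_BITS = 64               # storage word size; the 12 spare "nail" bits absorb carries
MASK = (1 << LIMB_BITS) - 1


def to_limbs(x, nlimbs):
    """Split a non-negative integer into `nlimbs` limbs of LIMB_BITS bits each."""
    assert 0 <= x < 1 << (LIMB_BITS * nlimbs)
    return [(x >> (LIMB_BITS * i)) & MASK for i in range(nlimbs)]


def add_no_carry(acc, limbs):
    """Limb-wise addition with no carry propagation: every lane is independent,
    which is what makes the operation SIMD-friendly."""
    out = [a + b for a, b in zip(acc, limbs)]
    # In hardware each lane is a 64-bit word; check that the nails have not overflowed.
    assert all(v < 1 << WORD_BITS for v in out), "nail bits exhausted: normalize first"
    return out


def normalize(limbs):
    """Propagate all pending carries in a single pass; returns (limbs, carry_out)."""
    out, carry = [], 0
    for v in limbs:
        v += carry
        out.append(v & MASK)
        carry = v >> LIMB_BITS
    return out, carry


def from_limbs(limbs, carry_out=0):
    """Convert normalized limbs (plus carry-out) back to a Python integer."""
    value = sum(v << (LIMB_BITS * i) for i, v in enumerate(limbs))
    return value + (carry_out << (LIMB_BITS * len(limbs)))


def sum_delayed(xs, nlimbs):
    """Sum many numbers with carries deferred to one final normalization.
    With 12 nail bits, up to 2**12 = 4096 addends fit before any lane can overflow."""
    acc = [0] * nlimbs
    for x in xs:
        acc = add_no_carry(acc, to_limbs(x, nlimbs))
    return normalize(acc)
```

The design choice is the usual one for nails: giving up 12 of every 64 bits makes each addition lane-independent, so a whole dot product can be accumulated before paying for a single sequential carry pass.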