(Factorial)
How do you compute factorial in parallel?
Well, the usual proposal is to just split the range in half and multiply the numbers in each half in a different thread.
new Fact(2 * 1024 * 1024).invoke(); // compute like so

// imports used by the snippets in this post:
// java.math.BigInteger, java.util.concurrent.ForkJoinTask, java.util.concurrent.RecursiveTask
static class Fact extends RecursiveTask<BigInteger> {
    int from;
    int to;

    public Fact(int n) {
        this(0, n);
    }

    private Fact(int from, int to) {
        this.from = from;
        this.to = to;
    }

    @Override
    protected BigInteger compute() {
        if (from == to) {
            // a single number; 0 stands in for the neutral element
            return from == 0 ? BigInteger.ONE : BigInteger.valueOf(from);
        }
        int mid = (from + to) >>> 1;
        ForkJoinTask<BigInteger> left = new Fact(from, mid).fork(); // lower half, in another thread
        BigInteger right = new Fact(mid + 1, to).invoke();          // upper half, in this thread
        return right.multiply(left.join());
    }
}
But that is not it.
You may notice that "left" and "right" parts of the computation end up computing numbers of very different magnitude: the left multiplies only small numbers, the right multiplies only big numbers.
Can we split the numbers in such a way that both sides multiply numbers of roughly the same magnitude?
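To put numbers on that imbalance, here is a small sequential sketch. It is not part of the measured code: the class and method names (`SplitMagnitudes`, `contiguousBits`, `interleavedBits`) are made up for illustration, and n = 1000 is an arbitrary small size. It compares the bit lengths of the two partial products when 1..n is cut in the middle and when it is split into odd and even numbers.

```java
import java.math.BigInteger;

final class SplitMagnitudes {
    // Sequentially multiplies from, from + step, from + 2*step, ... up to 'to' inclusive.
    static BigInteger product(int from, int to, int step) {
        BigInteger p = BigInteger.ONE;
        for (int i = from; i <= to; i += step) {
            p = p.multiply(BigInteger.valueOf(i));
        }
        return p;
    }

    // Bit lengths of the two partial products when 1..n is cut in the middle.
    static int[] contiguousBits(int n) {
        int mid = n / 2;
        return new int[] { product(1, mid, 1).bitLength(), product(mid + 1, n, 1).bitLength() };
    }

    // Bit lengths of the two partial products when 1..n is split into odd and even numbers.
    static int[] interleavedBits(int n) {
        return new int[] { product(1, n, 2).bitLength(), product(2, n, 2).bitLength() };
    }

    public static void main(String[] args) {
        int n = 1000; // illustrative size only, far smaller than the 2M used in the post
        int[] c = contiguousBits(n);
        int[] s = interleavedBits(n);
        System.out.println("contiguous halves: " + c[0] + " vs " + c[1] + " bits");
        System.out.println("odd/even halves:   " + s[0] + " vs " + s[1] + " bits");
    }
}
```

With a contiguous cut the two results differ by hundreds of bits at this size, while the odd/even split keeps them within a few bits of each other.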
Ensure even distribution of work
Like, even and odd? Then we can split the even numbers into those that are divisible by 4 and those that are not, then those divisible by 4 into those divisible by 8 and those that are not, and so on.
Of course, we don't need to test divisibility for this; we can just double the enumeration step at each level, and that way we can extend this method to the odd numbers, too:
static class Fact2 extends RecursiveTask<BigInteger> {
    int from;
    int to;
    int step;

    public Fact2(int n) {
        this(0, n, 1);
    }

    private Fact2(int from, int to, int step) {
        this.from = from;
        this.to = to;
        this.step = step;
    }

    @Override
    protected BigInteger compute() {
        if (from + step > to) {
            // only 'from' is left in this subsequence; 0 and 1 both count as 1
            return from < 2 ? BigInteger.ONE : BigInteger.valueOf(from);
        }
        // from, from + 2*step, ... go left; from + step, from + 3*step, ... go right
        ForkJoinTask<BigInteger> left = new Fact2(from, to, step << 1).fork();
        BigInteger right = new Fact2(from + step, to, step << 1).invoke();
        return right.multiply(left.join());
    }
}
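To see which numbers end up multiplied together under this step-doubling scheme, here is a tiny sketch that mirrors the recursion shape of Fact2 but returns the grouping as text instead of a product. The class and method names (`StepSplit`, `split`) are hypothetical and exist only for this illustration.

```java
final class StepSplit {
    // Same recursion as Fact2.compute(), but builds a string showing the pairing.
    static String split(int from, int to, int step) {
        if (from + step > to) {
            return Integer.toString(from); // a single number is left in this subsequence
        }
        String left = split(from, to, step << 1);         // every other number, starting at from
        String right = split(from + step, to, step << 1); // the numbers in between
        return "(" + left + " " + right + ")";
    }

    public static void main(String[] args) {
        // prints (((0 4) (2 6)) ((1 5) (3 7)))
        System.out.println(split(0, 7, 1));
    }
}
```

At every level the two operands come from interleaved subsequences, so they cover the same span of the range. In Fact2 the leaves 0 and 1 are both replaced by BigInteger.ONE.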
But that's not all.
Evens are odd
Well, half of the even numbers are 2x an odd number, and the other half are 2x an even number. So we can recursively reduce the problem to a multiplication of odd numbers, and put the factors of 2 back with a single shift at the end.
public static BigInteger fact3(int n) {
    if (n == 0) {
        return BigInteger.ONE;
    }
    Fact3 f = new Fact3(n);
    BigInteger r = f.invoke();
    return r.shiftLeft(f.shift); // put back the factors of 2 that were pulled out
}

static class Fact3 extends RecursiveTask<BigInteger> {
    int from;
    int to;
    int step;
    public int shift; // power of 2 pulled out of the numbers in this subtree (valid after compute())

    public Fact3(int n) {
        this(1, n, 1, 0);
    }

    private Fact3(int from, int to, int step, int shift) {
        this.from = from;
        this.to = to;
        this.step = step;
        this.shift = shift;
    }

    @Override
    protected BigInteger compute() {
        if (from + step > to) {
            return from < 2 ? BigInteger.ONE : BigInteger.valueOf(from);
        }
        if (from == 2 && step > 1) {
            // 2, 2 + step, ... are all even: halve them and remember one factor of 2 per number
            from = 1;
            to >>>= 1;
            step >>= 1;
            shift++;
        }
        Fact3 lf = new Fact3(from, to, step << 1, shift);
        ForkJoinTask<BigInteger> left = lf.fork();
        Fact3 rf = new Fact3(from + step, to, step << 1, shift);
        BigInteger right = rf.invoke().multiply(left.join());
        shift = lf.shift + rf.shift;
        return right;
    }
}
Now, this starts to look interesting.
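The reduction behind Fact3 can be written down without any threads. Since the even numbers 2, 4, ..., 2*(n/2) are (n/2)! with one factor of 2 pulled out of each, we get n! = (product of odd numbers up to n) * (n/2)! * 2^(n/2), and the same applies to (n/2)! in turn. The following is a sequential sketch of that identity only, not a rewrite of Fact3; `OddReduction` and its methods are names invented for the example.

```java
import java.math.BigInteger;

final class OddReduction {
    // Product of the odd numbers 3, 5, 7, ... not exceeding n (1 is skipped, it changes nothing).
    static BigInteger oddProduct(int n) {
        BigInteger p = BigInteger.ONE;
        for (int k = 3; k <= n; k += 2) {
            p = p.multiply(BigInteger.valueOf(k));
        }
        return p;
    }

    // Unrolls n! = oddProduct(n) * (n/2)! * 2^(n/2) until nothing is left to halve.
    static BigInteger factorial(int n) {
        BigInteger odd = BigInteger.ONE; // accumulated product of odd parts
        int shift = 0;                   // accumulated power of 2
        for (int m = n; m > 1; m >>= 1) {
            odd = odd.multiply(oddProduct(m));
            shift += m >> 1; // m/2 even numbers, one factor of 2 from each
        }
        return odd.shiftLeft(shift);
    }

    public static void main(String[] args) {
        System.out.println(factorial(10)); // prints 3628800
    }
}
```

Note that this version still multiplies the small odd numbers once per level, which is exactly the repetition the next section goes after.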
Eliminate repetitions
Aren't we duplicating work? Since we are multiplying only odd numbers, it seems we multiply the same numbers again and again. If we could split the odd numbers into subsequences, we should be able to multiply 1x3x5 separately, 7x9 separately, 11x13x15x17x19, etc, then reuse the results. We should end up with half the multiplications needed to compute factorial (since we only multiply odd numbers).
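To make the grouping concrete before the parallel version, here is a minimal sequential sketch, assuming the groups are the odd numbers in the ranges (n >> (k+1), n >> k]. For n = 20 these are {1}, {}, {3, 5}, {7, 9} and {11, ..., 19}. An odd number m in such a range shows up in 1..n as m, 2m, ..., 2^k * m, so its group is raised to the power k + 1. The names `OddGroups` and `oddsBetween` are invented for the example.

```java
import java.math.BigInteger;

final class OddGroups {
    // Product of the odd numbers in the half-open range (lo, hi].
    static BigInteger oddsBetween(int lo, int hi) {
        BigInteger p = BigInteger.ONE;
        for (int m = (lo + 1) | 1; m <= hi; m += 2) { // (lo + 1) | 1 is the smallest odd number above lo
            p = p.multiply(BigInteger.valueOf(m));
        }
        return p;
    }

    static BigInteger factorial(int n) {
        if (n < 2) {
            return BigInteger.ONE;
        }
        int levels = 32 - Integer.numberOfLeadingZeros(n); // bit length of n
        BigInteger result = BigInteger.ONE;
        int shift = 0;
        for (int k = levels - 1; k >= 0; k--) {
            int lo = n >> (k + 1);
            int hi = n >> k;
            // each odd m in (lo, hi] occurs k + 1 times: as m, 2m, ..., 2^k * m
            result = result.multiply(oddsBetween(lo, hi).pow(k + 1));
            shift += lo; // the sum of n >> j over j >= 1 is the power of 2 in n!
        }
        return result.shiftLeft(shift);
    }

    public static void main(String[] args) {
        System.out.println(factorial(20)); // prints 2432902008176640000
    }
}
```

Every odd number is multiplied into its group exactly once; the price is a pow() per group.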
public static BigInteger fact4(int n) {
    BigInteger res = BigInteger.ONE;
    if (n < 2) { // 0! = 1! = 1; for n == 1 there would be no forks to roll up
        return res;
    }
    // sz[i] = n >> (sz.length - 1 - i), so sz[0] == 1 and sz[sz.length - 1] == n
    int[] sz = new int[32 - Integer.numberOfLeadingZeros(n)];
    for (int i = sz.length; i-- > 0; n >>= 1) {
        sz[i] = n;
    }
    int ds = 1; // how many odd numbers have been covered so far (starting with 1)
    ForkJoinTask<BigInteger> one = new Fact4(1, 1).fork();
    int b = sz.length - 1;
    int shifts = b;
    ForkJoinTask<BigInteger>[] forks = new ForkJoinTask[sz.length - 1];
    for (int i = 1; i < sz.length; i++) {
        int pow = b; // odd numbers in this range occur pow times in 1..n (times 1, 2, 4, ...)
        b -= 1;
        int from = sz[i - 1];
        int to = sz[i];
        from = (from & 1) == 1 ? from + 2 : from + 1; // from is an odd number greater than the end of the previous range
        to = (to & 1) == 1 ? to + 1 : to;
        ds += (to + 1 - from) >> 1; // how many odd numbers fall into this range
        shifts += ds * b; // ds odd numbers times 2^b are <= n; each contributes b factors of 2
        Fact4 f = new Fact4(from, to);
        forks[i - 1] = from > to ? one : ForkJoinTask.adapt(() -> f.invoke().pow(pow));
    }
    return forks(forks, 0, forks.length, 1).invoke().shiftLeft(shifts); // roll up the forks
}

// this method is similar to Fact4 - preserving the magnitude of numbers seems important
static ForkJoinTask<BigInteger> forks(ForkJoinTask<BigInteger>[] fs, int from, int to, int step) {
    if (from + step >= to) {
        return fs[from];
    }
    ForkJoinTask<BigInteger> left = forks(fs, from, to, step << 1);
    ForkJoinTask<BigInteger> right = forks(fs, from + step, to, step << 1);
    return ForkJoinTask.adapt(() -> {
        left.fork();
        return right.invoke().multiply(left.join());
    });
}

// multiplies from, from + 2, from + 4, ... up to 'to' with the same step-doubling split as Fact2
static class Fact4 extends RecursiveTask<BigInteger> {
    int from;
    int to;
    int step;

    public Fact4(int from, int to) {
        this(from, to, 2);
    }

    private Fact4(int from, int to, int step) {
        this.from = from;
        this.to = to;
        this.step = step;
    }

    @Override
    protected BigInteger compute() {
        if (from + step > to) {
            return from < 2 ? BigInteger.ONE : BigInteger.valueOf(from);
        }
        ForkJoinTask<BigInteger> left = new Fact4(from, to, step << 1).fork();
        return new Fact4(from + step, to, step << 1).invoke().multiply(left.join());
    }
}
Now, looks like we are done.
$ /Library/Java/JavaVirtualMachines/graalvm-ee-complete-java11-20.2.0/Contents/Home/bin/java a
Fact straightforward: 4.984
Fact multiply equally sized nums: 4.754
Fact multiply only odd: 4.409
Fact multiply only odd, and only once: 5.603
Yay, splitting the work so that both sides multiply numbers of similar size pays off, and getting rid of even numbers is a good idea, too. Unfortunately, halving the number of multiplications does not result in a better time.
Did I mess up?
Single-threaded
OK, let's do all of that single-threaded. A win in the number of multiplications should show:
public static BigInteger fact4_1(int n) {
    BigInteger res = BigInteger.ONE;
    if (n == 0) {
        return res;
    }
    int[] sz = new int[32 - Integer.numberOfLeadingZeros(n)];
    for (int i = sz.length; i-- > 0; n >>= 1) {
        sz[i] = n;
    }
    int ds = 1;
    ForkJoinTask<BigInteger> one = new Fact4(1, 1);
    int b = sz.length - 1;
    int shifts = b;
    ForkJoinTask<BigInteger>[] forks = new ForkJoinTask[sz.length - 1];
    for (int i = 1; i < sz.length; i++) {
        int pow = b;
        b -= 1;
        int from = sz[i - 1];
        int to = sz[i];
        from = (from & 1) == 1 ? from + 2 : from + 1; // from is an odd number greater than the end of the previous range
        to = (to & 1) == 1 ? to + 1 : to;
        ds += (to + 1 - from) >> 1; // how many odd numbers fall into this range
        shifts += ds * b; // ds odd numbers times 2^b are <= n; each contributes b factors of 2
        forks[i - 1] = from > to ? one : new Fact4(from, to);
    }
    // prev is the running product of all odd numbers seen so far; it is reused instead of pow()
    BigInteger prev = BigInteger.ONE;
    for (ForkJoinTask<BigInteger> f : forks) {
        prev = prev.multiply(f.invoke());
        res = res.multiply(prev);
    }
    return res.shiftLeft(shifts);
}
$ /Library/Java/JavaVirtualMachines/graalvm-ee-complete-java11-20.2.0/Contents/Home/bin/java -Djava.util.concurrent.ForkJoinPool.common.parallelism=0 a
Fact straightforward: 7.365
Fact multiply equally sized nums: 7.084
Fact multiply only odd: 6.952
Fact multiply only odd, and only once: 8.291
Fact multiply only odd, and only once, single-threaded: 5.616
Ok, we do eliminate half the multiplications.
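The "half the multiplications" claim can be checked by counting. The sketch below is a stripped-down sequential variant, not the measured code: `MultiplyCount`, `mul`, `naive` and `grouped` are illustrative names, and the counter is a plain static field. It routes every multiplication through a counting helper and compares a plain loop with the running-prefix scheme used in fact4_1.

```java
import java.math.BigInteger;

final class MultiplyCount {
    static long multiplications; // number of multiply calls since the last reset

    static BigInteger mul(BigInteger a, BigInteger b) {
        multiplications++;
        return a.multiply(b);
    }

    // Plain loop: n - 1 multiplications for n >= 2.
    static BigInteger naive(int n) {
        BigInteger r = BigInteger.ONE;
        for (int k = 2; k <= n; k++) {
            r = mul(r, BigInteger.valueOf(k));
        }
        return r;
    }

    // Odd numbers only, each multiplied in once; groups are reused through a running prefix.
    static BigInteger grouped(int n) {
        BigInteger res = BigInteger.ONE;
        BigInteger prev = BigInteger.ONE; // product of all odd numbers visited so far
        int shift = 0;
        int levels = 32 - Integer.numberOfLeadingZeros(n);
        for (int k = levels - 1; k >= 0; k--) {
            int lo = n >> (k + 1);
            int hi = n >> k;
            for (int m = (lo + 1) | 1; m <= hi; m += 2) {
                prev = mul(prev, BigInteger.valueOf(m));
            }
            res = mul(res, prev); // one extra multiplication per level
            shift += lo;          // power of 2 contributed by this level
        }
        return res.shiftLeft(shift);
    }

    public static void main(String[] args) {
        multiplications = 0;
        BigInteger a = naive(1000);
        long plain = multiplications;
        multiplications = 0;
        BigInteger b = grouped(1000);
        System.out.println(a.equals(b) + ": " + plain + " vs " + multiplications + " multiplications");
    }
}
```

For n = 1000 the plain loop needs 999 multiplications and the grouped one 510 (500 odd numbers plus one per level), so the count is indeed roughly halved. The count alone says nothing about the size of the operands, which is what the timings are sensitive to.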
no subject
Date: 2021-11-27 03:24 pm (UTC)
Wow, so nice, funny and impressive! Now the question is, how come we don't gain much. Especially the result #4, it spends too much time. Threading takes time.
no subject
Date: 2021-11-27 03:51 pm (UTC)
There are a few aspects to consider:
* multiplying "worse" numbers. (Too small vs too large; I can see that multiplying sequentially is worse than multiplying "every other number" like forks(), so it must be down to that)
* waiting for dependent chains (I can see all others go to 240% CPU, then gradually decrease to 180%; the version with fewer multiplications goes to 170% only, then goes down). So the CPU cost is lower, just it is not translated into time
* actually, the parallel version does pow() as well - that's an attempt to not wait for dependent chains at the cost of a few more multiplications; the single-threaded version doesn't do that, and just keeps accumulating prev.
no subject
Date: 2021-11-27 04:34 pm (UTC)
The seconds of execution time - that's because all of these are computing factorial of 2M. The single-threaded execution time explains how much the factorial actually takes. Therefore, I believe the difference between:
Fact straightforward: 7.365
Fact multiply equally sized nums: 7.084
is entirely due to what numbers we multiply together.
no subject
Date: 2021-11-27 05:43 pm (UTC)
Right; it might make sense to group evenly, by the numbers length, starting with both ends.
no subject
Date: 2021-11-27 07:01 pm (UTC)
Here's what I get:
def fact(n, xs=None):  # the even-odd interlacing method
    if n == 0:
        return 1
    shifts = [0]
    def f(b, e, s, sh, shifts, xs):
        if b == 2:
            b = b >> 1
            e = e >> 1
            s = s >> 1
            sh += 1
        if b + s > e:
            shifts[0] += sh
            return b
        x = f(b, e, s << 1, sh, shifts, xs)
        y = f(b + s, e, s << 1, sh, shifts, xs)
        r = x * y
        if xs is not None:
            xs.append((x.bit_length(), y.bit_length(), r.bit_length()))
        return r
    return f(1, n, 1, 0, shifts, xs) << shifts[0]

print('-------------------')
print('burn the candle on both ends:')
ys = []
xs = [x for x in range(1, 2001)]
l = len(xs) - 1
while l > 0:
    i = 0
    while i < l - i:
        x = xs[i]
        y = xs[l - i]
        r = x * y
        xs[i] = r
        ys.append((x.bit_length(), y.bit_length(), r.bit_length()))
        i += 1
    if i == l - i:
        l = i
    else:
        l = i - 1
    xs = xs[0:l+1]
    xs.sort(key=lambda x: x.bit_length())
res_candle = xs[0]
print('-------------------')
print('even-odd:')
xs = []
res_interlace = fact(2000, xs)
xs.sort(key=lambda x: (x[2], x[0], x[1]))
ys.sort(key=lambda x: (x[2], x[0], x[1]))
for x, y in zip(xs, ys):
    print('%d * %d = %d' % (x[0] - y[0], x[1] - y[1], x[2] - y[2]))
print(res_candle == res_interlace)
Here the even-odd is a clear winner: most of the time its operands and the result of their product is much shorter - 10...50 bits shorter, to be precise. The tail of this print out looks like this:
well.... of course, the main difference is because the even-odd method doesn't use even numbers at all, and the "burning the candle from two ends" does. I am not sure how you'd like to apply the method, if we eliminate the even numbers.
If I use the even-odd method that doesn't eliminate even numbers, then the spread of bit sizes is similar to "burn the candle from both ends" method.
no subject
Date: 2021-11-27 10:22 pm (UTC)
Good point; I'm trying to figure out how to combine both approaches.
no subject
Date: 2021-11-28 02:14 pm (UTC)
Like, if we interlace, we end up multiplying 1x1M in the first pair, and 999999x2M in the last pair. If we somehow manage to burn the candle on both ends, we'll multiply 1x2M in the first pair and 999999x1M in the last pair. The biggest and the smallest products differ by only 1 bit between the two approaches.
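The arithmetic in that last comparison is easy to check. Here is a small sketch, assuming n = 2,000,000 (decimal, to match the 999999 above) and assuming the pairs are i * (i + n/2) for interlacing and i * (n + 1 - i) for the candle; `PairBits` and its methods are made-up names.

```java
final class PairBits {
    // Number of bits needed to represent a positive long.
    static int bits(long x) {
        return 64 - Long.numberOfLeadingZeros(x);
    }

    // Interlaced pairing of 1..n: bit lengths of the first and the last product, i * (i + n/2).
    static int[] interlaced(long n) {
        long h = n / 2;
        return new int[] { bits(1 * (1 + h)), bits(h * n) };
    }

    // Candle pairing of 1..n: bit lengths of the first and the last product, i * (n + 1 - i).
    static int[] candle(long n) {
        long h = n / 2;
        return new int[] { bits(1 * n), bits(h * (h + 1)) };
    }

    public static void main(String[] args) {
        long n = 2_000_000L; // assumed size, see above
        int[] a = interlaced(n);
        int[] b = candle(n);
        System.out.println("interlaced: " + a[0] + " and " + a[1] + " bits");
        System.out.println("candle:     " + b[0] + " and " + b[1] + " bits");
    }
}
```

Under these assumptions the smallest and largest products are 20 and 41 bits when interlacing, and 21 and 40 bits when burning the candle on both ends, so the extremes move by one bit each.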