/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.testing;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.transforms.Combine;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.junit.Assert;

/**
 * Utilities for testing a {@link Combine.CombineFn} under many sharding, ordering, compaction and
 * accumulator-merging strategies, asserting that every strategy yields the same output.
 */
public class CombineFnTester {
    /**
     * Checks that {@code fn} combines {@code input} into a value equal to {@code expected}.
     *
     * <p>All checks of {@link #testCombineFn(Combine.CombineFn, List, Matcher)} are run once on
     * {@code input} in its given order and once more on a shuffled copy of it. The caller's list
     * is never reordered, so unmodifiable lists are accepted.
     *
     * @param fn the combine function under test
     * @param input the elements to combine; left unmodified
     * @param expected the expected output, compared via {@link Matchers#is(Object)}
     */
    public static <InputT, AccumT, OutputT> void testCombineFn(
            Combine.CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, OutputT expected) {
        testCombineFn(fn, input, Matchers.is(expected));
        // Shuffle a private copy: shuffling the caller's list in place mutates it as a side effect
        // and throws UnsupportedOperationException for unmodifiable lists.
        List<InputT> shuffled = new ArrayList<>(input);
        Collections.shuffle(shuffled);
        testCombineFn(fn, shuffled, Matchers.is(expected));
    }

    /**
     * Checks that {@code fn} combines {@code input} into a value satisfying {@code matcher}.
     *
     * <p>The input is split into: a single shard; two even shards; for more than four elements,
     * {@code size / 2} and roughly {@code sqrt(size)} even shards; and exponentially shrinking
     * shards with bases 1.4, 2 and e. Each sharding is verified in its natural shard order and
     * in a random shard order.
     *
     * @param fn the combine function under test
     * @param input the elements to combine; shards are read-only sub-list views of it
     * @param matcher the matcher every extracted output must satisfy
     */
    public static <InputT, AccumT, OutputT> void testCombineFn(
            Combine.CombineFn<InputT, AccumT, OutputT> fn,
            List<InputT> input,
            Matcher<? super OutputT> matcher) {
        int size = input.size();
        checkCombineFnShardsMultipleOrders(fn, Collections.singletonList(input), matcher);
        checkCombineFnShardsMultipleOrders(fn, shardEvenly(input, 2), matcher);
        if (size > 4) {
            checkCombineFnShardsMultipleOrders(fn, shardEvenly(input, size / 2), matcher);
            // size / sqrt(size) is approximately sqrt(size) shards of sqrt(size) elements each.
            checkCombineFnShardsMultipleOrders(
                    fn, shardEvenly(input, (int) ((double) size / Math.sqrt(size))), matcher);
        }
        checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, 1.4), matcher);
        checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, 2.0), matcher);
        checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, Math.E), matcher);
    }

    /**
     * Runs the single-merge, empty-accumulator and incremental-merging checks on {@code shards},
     * then repeats all three with the shards in a random order.
     */
    private static <InputT, AccumT, OutputT> void checkCombineFnShardsMultipleOrders(
            Combine.CombineFn<InputT, AccumT, OutputT> fn,
            List<? extends Iterable<InputT>> shards,
            Matcher<? super OutputT> matcher) {
        checkCombineFnShardsSingleMerge(fn, shards, matcher);
        checkCombineFnShardsWithEmptyAccumulators(fn, shards, matcher);
        checkCombineFnShardsIncrementalMerging(fn, shards, matcher);
        // Shuffle a copy so the shard list passed in is never modified.
        List<Iterable<InputT>> shuffledShards = new ArrayList<>(shards);
        Collections.shuffle(shuffledShards);
        checkCombineFnShardsSingleMerge(fn, shuffledShards, matcher);
        checkCombineFnShardsWithEmptyAccumulators(fn, shuffledShards, matcher);
        checkCombineFnShardsIncrementalMerging(fn, shuffledShards, matcher);
    }

    /** Combines each shard separately, merges all accumulators in one call, and checks the output. */
    private static <InputT, AccumT, OutputT> void checkCombineFnShardsSingleMerge(
            Combine.CombineFn<InputT, AccumT, OutputT> fn,
            Iterable<? extends Iterable<InputT>> shards,
            Matcher<? super OutputT> matcher) {
        List<AccumT> accumulators = combineInputs(fn, shards);
        AccumT merged = fn.mergeAccumulators(accumulators);
        Assert.assertThat(fn.extractOutput(merged), matcher);
    }

    /**
     * Like {@link #checkCombineFnShardsSingleMerge}, but surrounds the per-shard accumulators with
     * a freshly created (empty) accumulator at the front and at the back before merging, so the
     * result must be unaffected by empty accumulators.
     */
    private static <InputT, AccumT, OutputT> void checkCombineFnShardsWithEmptyAccumulators(
            Combine.CombineFn<InputT, AccumT, OutputT> fn,
            Iterable<? extends Iterable<InputT>> shards,
            Matcher<? super OutputT> matcher) {
        List<AccumT> accumulators = combineInputs(fn, shards);
        accumulators.add(0, fn.createAccumulator());
        accumulators.add(fn.createAccumulator());
        AccumT merged = fn.mergeAccumulators(accumulators);
        Assert.assertThat(fn.extractOutput(merged), matcher);
    }

    /**
     * Folds the per-shard accumulators together two at a time, extracting an intermediate output
     * after each step, and checks the final output. With no shards, a freshly created accumulator
     * is checked instead.
     */
    private static <InputT, AccumT, OutputT> void checkCombineFnShardsIncrementalMerging(
            Combine.CombineFn<InputT, AccumT, OutputT> fn,
            List<? extends Iterable<InputT>> shards,
            Matcher<? super OutputT> matcher) {
        List<AccumT> inputAccums = combineInputs(fn, shards);
        AccumT accumulator;
        if (inputAccums.isEmpty()) {
            accumulator = fn.createAccumulator();
        } else {
            // Track the first accumulator explicitly rather than with a null sentinel, so a
            // CombineFn whose accumulator may legitimately be null is folded correctly.
            accumulator = inputAccums.get(0);
            // Extract (and discard) an intermediate output after every step so that any damage
            // extractOutput does to the accumulator shows up in the final assertion.
            fn.extractOutput(accumulator);
            for (AccumT inputAccum : inputAccums.subList(1, inputAccums.size())) {
                accumulator = fn.mergeAccumulators(Arrays.asList(accumulator, inputAccum));
                fn.extractOutput(accumulator);
            }
        }
        Assert.assertThat(fn.extractOutput(accumulator), matcher);
    }

    /**
     * Builds one accumulator per shard by adding each of the shard's elements to a fresh
     * accumulator. Every other accumulator (the 1st, 3rd, 5th, ...) is additionally passed through
     * {@code compact}, so compacted and non-compacted accumulators end up being merged together.
     *
     * @return a mutable list with one accumulator per shard, in shard order
     */
    private static <InputT, AccumT, OutputT> List<AccumT> combineInputs(
            Combine.CombineFn<InputT, AccumT, OutputT> fn, Iterable<? extends Iterable<InputT>> shards) {
        List<AccumT> accumulators = new ArrayList<>();
        int shardIndex = 0;
        for (Iterable<InputT> shard : shards) {
            AccumT accumulator = fn.createAccumulator();
            for (InputT elem : shard) {
                accumulator = fn.addInput(accumulator, elem);
            }
            if (shardIndex++ % 2 == 0) {
                accumulator = fn.compact(accumulator);
            }
            accumulators.add(accumulator);
        }
        return accumulators;
    }

    /**
     * Splits {@code input} into {@code numShards} contiguous sub-list views whose sizes differ by
     * at most one. Shards may be empty when {@code numShards > input.size()}.
     */
    private static <T> List<List<T>> shardEvenly(List<T> input, int numShards) {
        int size = input.size();
        List<List<T>> shards = new ArrayList<>(numShards);
        for (int i = 0; i < numShards; ++i) {
            // Compute the boundaries in long arithmetic: with numShards == size / 2 the product
            // (i + 1) * size overflows int once the input has roughly 65k elements.
            int start = (int) ((long) i * size / numShards);
            int end = (int) ((long) (i + 1) * size / numShards);
            shards.add(input.subList(start, end));
        }
        return shards;
    }

    /**
     * Splits {@code input} into contiguous sub-list views taken from the end backwards. Each shard
     * starts at {@code floor(end / base)}, so shard sizes shrink geometrically towards the front.
     * An empty input yields no shards.
     *
     * @throws IllegalArgumentException if {@code base} is not greater than 1
     */
    private static <T> List<List<T>> shardExponentially(List<T> input, double base) {
        // Explicit check rather than `assert`: with assertions disabled (the JVM default),
        // base == 1 would loop forever and base < 1 would fail obscurely inside subList.
        if (!(base > 1.0)) {
            throw new IllegalArgumentException("base must be greater than 1, got " + base);
        }
        List<List<T>> shards = new ArrayList<>();
        int end = input.size();
        while (end > 0) {
            int start = (int) (end / base);
            shards.add(input.subList(start, end));
            end = start;
        }
        return shards;
    }
}

