/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators.chaining;

import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.FlatMapDriver;
import org.apache.flink.runtime.operators.FlatMapTaskTest;
import org.apache.flink.runtime.operators.ReduceTaskTest;
import org.apache.flink.runtime.operators.chaining.SynchronousChainedCombineDriver;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.testutils.TaskTestBase;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.runtime.taskmanager.Task;
import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory;
import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;

@RunWith(value=PowerMockRunner.class)
@PrepareForTest(value={Task.class, ResultPartitionWriter.class})
public class ChainTaskTest
extends TaskTestBase {
    private static final int MEMORY_MANAGER_SIZE = 0x300000;
    private static final int NETWORK_BUFFER_SIZE = 1024;
    private final List<Record> outList = new ArrayList<Record>();
    private final RecordComparatorFactory compFact = new RecordComparatorFactory(new int[]{0}, new Class[]{IntValue.class}, new boolean[]{true});
    private final RecordSerializerFactory serFact = RecordSerializerFactory.get();

    @Test
    public void testMapTask() {
        int keyCnt = 100;
        int valCnt = 20;
        double memoryFraction = 1.0;
        try {
            this.initEnvironment(0x300000L, 1024);
            this.addInput(new UniformRecordGenerator(100, 20, false), 0);
            this.addOutput(this.outList);
            TaskConfig combineConfig = new TaskConfig(new Configuration());
            combineConfig.addInputToGroup(0);
            combineConfig.setInputSerializer((TypeSerializerFactory)this.serFact, 0);
            combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            combineConfig.setOutputSerializer((TypeSerializerFactory)this.serFact);
            combineConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            combineConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 0);
            combineConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 1);
            combineConfig.setRelativeMemoryDriver(1.0);
            combineConfig.setStubWrapper((UserCodeWrapper)new UserCodeClassWrapper(ReduceTaskTest.MockCombiningReduceStub.class));
            this.getTaskConfig().addChainedTask(SynchronousChainedCombineDriver.class, combineConfig, "combine");
            BatchTask testTask = new BatchTask();
            this.registerTask((AbstractInvokable)testTask, FlatMapDriver.class, FlatMapTaskTest.MockMapStub.class);
            try {
                testTask.invoke();
            }
            catch (Exception e) {
                e.printStackTrace();
                Assert.fail((String)"Invoke method caused exception.");
            }
            Assert.assertEquals((long)100L, (long)this.outList.size());
        }
        catch (Exception e) {
            e.printStackTrace();
            Assert.fail((String)e.getMessage());
        }
    }

    @Test
    public void testFailingMapTask() {
        int keyCnt = 100;
        int valCnt = 20;
        long memorySize = 0x300000L;
        int bufferSize = 1038336;
        double memoryFraction = 1.0;
        try {
            this.initEnvironment(0x300000L, 1038336);
            this.addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);
            this.addOutput(this.outList);
            TaskConfig combineConfig = new TaskConfig(new Configuration());
            combineConfig.addInputToGroup(0);
            combineConfig.setInputSerializer((TypeSerializerFactory)this.serFact, 0);
            combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            combineConfig.setOutputSerializer((TypeSerializerFactory)this.serFact);
            combineConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            combineConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 0);
            combineConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 1);
            combineConfig.setRelativeMemoryDriver(1.0);
            combineConfig.setStubWrapper((UserCodeWrapper)new UserCodeClassWrapper(MockFailingCombineStub.class));
            this.getTaskConfig().addChainedTask(SynchronousChainedCombineDriver.class, combineConfig, "combine");
            BatchTask testTask = new BatchTask();
            super.registerTask((AbstractInvokable)testTask, FlatMapDriver.class, FlatMapTaskTest.MockMapStub.class);
            boolean stubFailed = false;
            try {
                testTask.invoke();
            }
            catch (Exception e) {
                stubFailed = true;
            }
            Assert.assertTrue((String)"Function exception was not forwarded.", (boolean)stubFailed);
        }
        catch (Exception e) {
            e.printStackTrace();
            Assert.fail((String)e.getMessage());
        }
    }

    public static final class MockFailingCombineStub
    implements GroupReduceFunction<Record, Record>,
    GroupCombineFunction<Record, Record> {
        private static final long serialVersionUID = 1L;
        private int cnt = 0;

        public void reduce(Iterable<Record> records, Collector<Record> out) throws Exception {
            if (++this.cnt >= 5) {
                throw new RuntimeException("Expected Test Exception");
            }
            for (Record r : records) {
                out.collect((Object)r);
            }
        }

        public void combine(Iterable<Record> values, Collector<Record> out) throws Exception {
            this.reduce(values, out);
        }
    }
}

