package org.apache.pig.backend.hadoop.executionengine.spark.converter;

import com.google.common.base.Optional;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.PhysicalOperator;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.plans.PhysicalPlan;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POLocalRearrange;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POSkewedJoin;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkPigContext;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.builtin.PartitionSkewedKeys;
import org.apache.pig.impl.plan.NodeIdGenerator;
import org.apache.pig.impl.plan.OperatorKey;
import org.apache.pig.impl.plan.PlanException;
import org.apache.pig.impl.util.MultiMap;
import org.apache.pig.impl.util.Pair;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;
import scala.runtime.AbstractFunction1;

/* loaded from: input_file:org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.class */
/**
 * Converts a {@link POSkewedJoin} into a Spark join whose partitioning follows a sampled key
 * distribution.
 *
 * <p>The first predecessor is the skewed input. Each of its records is tagged with a partition id
 * that cycles through the reducer range the key distribution assigns to its key. The second
 * predecessor is the streamed input. Each of its records is emitted once per reducer in that range,
 * so it meets every partition the skewed key was spread over. Keys that are absent from the
 * distribution get partition id -1 and fall back to hash partitioning in
 * {@link SkewedJoinPartitioner}.
 */
public class SkewedJoinConverter implements RDDConverter<Tuple, Tuple, POSkewedJoin>, Serializable {
    private static final Log log = LogFactory.getLog(SkewedJoinConverter.class);

    /** Key-extracting local rearranges: index 0 for the skewed input, index 1 for the streamed input. */
    private POLocalRearrange[] LRs;
    private POSkewedJoin poSkewedJoin;
    /** Lookup key of the broadcast variable holding the sampled key distribution. */
    private String skewedJoinPartitionFile;

    /**
     * {@link IndexedKey} that also carries the reducer partition chosen for the record.
     * A partition id of -1 means "not chosen"; {@link SkewedJoinPartitioner} then hashes the key.
     */
    public static class PartitionIndexedKey extends IndexedKey {
        int partitionId;

        public PartitionIndexedKey(byte index, Object key) {
            this(index, key, -1);
        }

        public PartitionIndexedKey(byte index, Object key, int partitionId) {
            super(index, key);
            this.partitionId = partitionId;
        }

        public int getPartitionId() {
            return this.partitionId;
        }

        private void setPartitionId(int partitionId) {
            this.partitionId = partitionId;
        }

        @Override
        public String toString() {
            return "PartitionIndexedKey{index=" + ((int) getIndex()) + ", partitionId=" + getPartitionId() + ", key=" + getKey() + '}';
        }
    }

    /**
     * Keys records of the skewed input (LRs[0]).
     *
     * <p>Successive records with the same skewed key cycle through the reducer range
     * {@code [first, first + second]} from the key distribution. The chosen value is taken
     * modulo the parallelism.
     */
    public static class SkewPartitionIndexKeyFunction extends AbstractFunction1<Tuple, Tuple2<PartitionIndexedKey, Tuple>> implements Serializable {
        private final SkewedJoinConverter converter;
        private final Broadcast<List<Tuple>> keyDist;
        private final Integer defaultParallelism;
        protected transient Map<Tuple, Pair<Integer, Integer>> reducerMap;
        /** Last reducer index handed out per key, so the next record with that key uses the following one. */
        private transient Map<Tuple, Integer> currentIndexMap;
        private transient boolean initialized = false;
        private transient Integer parallelism = -1;

        public SkewPartitionIndexKeyFunction(SkewedJoinConverter converter, Broadcast<List<Tuple>> keyDist, Integer defaultParallelism) {
            this.converter = converter;
            this.keyDist = keyDist;
            this.defaultParallelism = defaultParallelism;
        }

        /**
         * Extracts the join key of {@code tuple} and pairs the tuple with a partition-tagged key.
         *
         * @throws RuntimeException wrapping the {@link ExecException} if key extraction fails. The
         *     original code returned {@code null} here, which only failed later inside the join.
         */
        @Override
        public Tuple2<PartitionIndexedKey, Tuple> apply(Tuple tuple) {
            POLocalRearrange lr = converter.LRs[0];
            lr.attachInput(tuple);
            try {
                Result lrResult = lr.getNextTuple();
                Tuple lrOut = (Tuple) lrResult.result;
                Byte index = (Byte) lrOut.get(0);
                Tuple key = (Tuple) lrOut.get(1);
                return new Tuple2<>(new PartitionIndexedKey(index.byteValue(), key, getPartitionId(key)), tuple);
            } catch (ExecException e) {
                throw new RuntimeException("Failed to extract the skewed join key", e);
            }
        }

        /** Returns the next reducer for {@code key}, or -1 if the key is not in the distribution. */
        private int getPartitionId(Tuple key) {
            if (!this.initialized) {
                Integer[] totalReducers = new Integer[1];
                this.reducerMap = loadKeyDistribution(this.keyDist, totalReducers);
                this.parallelism = effectiveParallelism(totalReducers[0], this.defaultParallelism);
                this.currentIndexMap = Maps.newHashMap();
                this.initialized = true;
            }
            Pair<Integer, Integer> range = this.reducerMap.get(key);
            if (range == null) {
                // Not a sampled key: let the partitioner hash it.
                return -1;
            }
            Integer last = this.currentIndexMap.get(key);
            int current = (last == null) ? -1 : last.intValue();
            int next;
            if (current == -1 || current >= range.first + range.second) {
                next = range.first;
            } else {
                next = current + 1;
            }
            this.currentIndexMap.put(key, next);
            return next % this.parallelism;
        }
    }

    /**
     * Partitions by the explicit partition id of a {@link PartitionIndexedKey} when it is
     * non-negative. Otherwise it uses a non-negative hash of the underlying key (or of the object
     * itself when it is not a {@link PartitionIndexedKey}).
     */
    public static class SkewedJoinPartitioner extends Partitioner {
        private final int numPartitions;

        public SkewedJoinPartitioner(int numPartitions) {
            this.numPartitions = numPartitions;
        }

        @Override
        public int numPartitions() {
            return this.numPartitions;
        }

        @Override
        public int getPartition(Object obj) {
            Object hashSource = obj;
            if (obj instanceof PartitionIndexedKey) {
                PartitionIndexedKey indexedKey = (PartitionIndexedKey) obj;
                int partitionId = indexedKey.getPartitionId();
                if (partitionId >= 0) {
                    return partitionId;
                }
                hashSource = indexedKey.getKey();
            }
            if (hashSource == null) {
                return 0;
            }
            int code = hashSource.hashCode() % this.numPartitions;
            return code >= 0 ? code : code + this.numPartitions;
        }
    }

    /**
     * Keys records of the streamed input (LRs[1]).
     *
     * <p>A record whose key is in the distribution is emitted {@code second + 1} times, with
     * partition ids starting at {@code first} and wrapping to 0 at the parallelism. Any other
     * record is emitted once with partition id -1.
     */
    public static class StreamPartitionIndexKeyFunction implements FlatMapFunction<Tuple, Tuple2<PartitionIndexedKey, Tuple>> {
        private final SkewedJoinConverter converter;
        private final Broadcast<List<Tuple>> keyDist;
        private final Integer defaultParallelism;
        private transient boolean initialized = false;
        protected transient Map<Tuple, Pair<Integer, Integer>> reducerMap;
        private transient Integer parallelism;

        public StreamPartitionIndexKeyFunction(SkewedJoinConverter converter, Broadcast<List<Tuple>> keyDist, Integer defaultParallelism) {
            this.converter = converter;
            this.keyDist = keyDist;
            this.defaultParallelism = defaultParallelism;
        }

        @Override
        public Iterable<Tuple2<PartitionIndexedKey, Tuple>> call(Tuple tuple) throws Exception {
            if (!this.initialized) {
                Integer[] totalReducers = new Integer[1];
                this.reducerMap = loadKeyDistribution(this.keyDist, totalReducers);
                this.parallelism = effectiveParallelism(totalReducers[0], this.defaultParallelism);
                this.initialized = true;
            }
            POLocalRearrange lr = converter.LRs[1];
            lr.attachInput(tuple);
            Result lrResult = lr.getNextTuple();
            Tuple lrOut = (Tuple) lrResult.result;
            Byte index = (Byte) lrOut.get(0);
            Tuple key = (Tuple) lrOut.get(1);

            Pair<Integer, Integer> range = this.reducerMap.get(key);
            if (range == null) {
                // Not a sampled key: a single copy with partition id -1 (hash partitioning).
                range = new Pair<>(-1, 0);
            }
            List<Tuple2<PartitionIndexedKey, Tuple>> copies = new ArrayList<>();
            int partition = range.first;
            for (int copy = 0; copy <= range.second; copy++) {
                if (partition >= this.parallelism) {
                    partition = 0;
                }
                copies.add(new Tuple2<>(new PartitionIndexedKey(index.byteValue(), key, partition), tuple));
                partition++;
            }
            return copies;
        }
    }

    /**
     * Flattens joined {@code (key, (left, right))} records into a single output tuple per record:
     * left fields followed by right fields.
     *
     * <p>For an outer side (inner flag false), the value is a Guava {@link Optional}. An absent side
     * is padded with {@code schemaSize[i]} nulls. When the left side is absent, the padded row is
     * produced only on the first reducer of the key's range (see {@link #isFirstReduceKey}). On the
     * other reducers the record is skipped, because the streamed record was replicated onto all of
     * them.
     */
    public static class ToValueFunction<L, R> implements FlatMapFunction<Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>>, Tuple>, Serializable {
        private final boolean[] innerFlags;
        private final int[] schemaSize;
        private final Broadcast<List<Tuple>> keyDist;
        private transient boolean initialized = false;
        protected transient Map<Tuple, Pair<Integer, Integer>> reducerMap;

        /** Lazily transforms a partition's joined records, dropping the ones that must be skipped. */
        public class Tuple2TransformIterable implements Iterable<Tuple> {
            Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> in;

            Tuple2TransformIterable(Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> in) {
                this.in = in;
            }

            @Override
            public Iterator<Tuple> iterator() {
                return new Iterator<Tuple>() {
                    // Next output tuple, computed ahead by hasNext(); null when not yet fetched.
                    private Tuple pending;

                    @Override
                    public boolean hasNext() {
                        // Iterate (not recurse) past skipped records so the end of the partition
                        // is detected correctly and long skip runs cannot overflow the stack.
                        while (pending == null && in.hasNext()) {
                            pending = toJoinedTuple(in.next());
                        }
                        return pending != null;
                    }

                    @Override
                    public Tuple next() {
                        if (!hasNext()) {
                            throw new NoSuchElementException();
                        }
                        Tuple result = pending;
                        pending = null;
                        return result;
                    }

                    @Override
                    public void remove() {
                        throw new UnsupportedOperationException("remove");
                    }
                };
            }
        }

        public ToValueFunction(boolean[] innerFlags, int[] schemaSize, Broadcast<List<Tuple>> keyDist) {
            this.innerFlags = innerFlags;
            this.schemaSize = schemaSize;
            this.keyDist = keyDist;
        }

        @Override
        public Iterable<Tuple> call(Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> it) {
            return new Tuple2TransformIterable(it);
        }

        /**
         * Builds the output tuple for one joined record.
         *
         * @return the concatenated tuple, or {@code null} if the record must be skipped
         * @throws RuntimeException wrapping an {@link ExecException} raised while reading fields
         */
        private Tuple toJoinedTuple(Tuple2<PartitionIndexedKey, Tuple2<L, R>> record) {
            try {
                Tuple joined = TupleFactory.getInstance().newTuple();
                Tuple left = unwrapSide(record._2()._1(), this.innerFlags[0]);
                if (left == null) {
                    if (!isFirstReduceKey(record._1())) {
                        return null;
                    }
                    appendNulls(joined, this.schemaSize[0]);
                } else {
                    appendAll(joined, left);
                }
                Tuple right = unwrapSide(record._2()._2(), this.innerFlags[1]);
                if (right == null) {
                    appendNulls(joined, this.schemaSize[1]);
                } else {
                    appendAll(joined, right);
                }
                if (log.isDebugEnabled()) {
                    log.debug("MJC: Result = " + joined.toDelimitedString(" "));
                }
                return joined;
            } catch (ExecException e) {
                throw new RuntimeException("Failed to build skewed join output tuple", e);
            }
        }

        /** Returns the side's tuple, or null when an outer side (a Guava Optional) is absent. */
        private static Tuple unwrapSide(Object side, boolean inner) {
            if (inner) {
                return (Tuple) side;
            }
            Optional<?> value = (Optional<?>) side;
            return value.isPresent() ? (Tuple) value.get() : null;
        }

        private static void appendAll(Tuple target, Tuple source) throws ExecException {
            for (int i = 0; i < source.size(); i++) {
                target.append(source.get(i));
            }
        }

        private static void appendNulls(Tuple target, int count) {
            for (int i = 0; i < count; i++) {
                target.append(null);
            }
        }

        /**
         * Returns true when this key's record should produce a null-padded row here. That is the
         * case when it carries no explicit partition, is not in the distribution, or its partition
         * equals the first reducer of the key's range.
         */
        public boolean isFirstReduceKey(PartitionIndexedKey partitionIndexedKey) {
            if (partitionIndexedKey.getPartitionId() == -1) {
                return true;
            }
            if (!this.initialized) {
                this.reducerMap = loadKeyDistribution(this.keyDist, new Integer[1]);
                this.initialized = true;
            }
            Pair<Integer, Integer> range = this.reducerMap.get(partitionIndexedKey.getKey());
            return range == null || partitionIndexedKey.getPartitionId() == range.first.intValue();
        }
    }

    public void setSkewedJoinPartitionFile(String skewedJoinPartitionFile) {
        this.skewedJoinPartitionFile = skewedJoinPartitionFile;
    }

    /**
     * Builds the skewed join RDD from exactly two predecessors: the skewed input first, the
     * streamed input second.
     */
    @Override
    public RDD<Tuple> convert(List<RDD<Tuple>> predecessors, POSkewedJoin poSkewedJoin) throws IOException {
        SparkUtil.assertPredecessorSize(predecessors, poSkewedJoin, 2);
        this.LRs = new POLocalRearrange[2];
        this.poSkewedJoin = poSkewedJoin;
        createJoinPlans(poSkewedJoin.getJoinPlans());

        RDD<Tuple> skewedInput = predecessors.get(0);
        RDD<Tuple> streamedInput = predecessors.get(1);

        // NOTE(review): the get() calls are kept from the original call sequence; presumably they
        // ensure the context singleton exists before its static accessors are used — confirm.
        SparkPigContext.get();
        Broadcast<List<Tuple>> keyDist = SparkPigContext.getBroadcastedVars().get(this.skewedJoinPartitionFile);
        SparkPigContext.get();
        Integer defaultParallelism = Integer.valueOf(SparkPigContext.getParallelism(predecessors, poSkewedJoin));

        JavaPairRDD<PartitionIndexedKey, Tuple> skewedPairs = new JavaPairRDD<>(
                skewedInput.map(new SkewPartitionIndexKeyFunction(this, keyDist, defaultParallelism),
                        SparkUtil.<PartitionIndexedKey, Tuple>getTuple2Manifest()),
                SparkUtil.getManifest(PartitionIndexedKey.class),
                SparkUtil.getManifest(Tuple.class));
        JavaPairRDD<PartitionIndexedKey, Tuple> streamedPairs = new JavaPairRDD<>(
                streamedInput.toJavaRDD().flatMap(new StreamPartitionIndexKeyFunction(this, keyDist, defaultParallelism)).rdd(),
                SparkUtil.getManifest(PartitionIndexedKey.class),
                SparkUtil.getManifest(Tuple.class));

        return doJoin(skewedPairs, streamedPairs, buildPartitioner(keyDist, defaultParallelism), keyDist).rdd();
    }

    /** Creates one key-extracting local rearrange per join input, in the join-plan key order. */
    private void createJoinPlans(MultiMap<PhysicalOperator, PhysicalPlan> joinPlans) throws PlanException {
        int index = 0;
        for (PhysicalOperator input : joinPlans.keySet()) {
            POLocalRearrange lr = new POLocalRearrange(genKey());
            try {
                lr.setIndex(index);
                // 110 is Pig's DataType.TUPLE type code.
                lr.setResultType((byte) 110);
                lr.setKeyType((byte) 110);
                lr.setPlans(joinPlans.get(input));
                this.LRs[index] = lr;
            } catch (ExecException e) {
                throw new PlanException(e.getMessage(), e.getErrorCode(), e.getErrorSource(), e);
            }
            index++;
        }
    }

    private OperatorKey genKey() {
        return new OperatorKey(this.poSkewedJoin.getOperatorKey().scope, NodeIdGenerator.getGenerator().getNextNodeId(this.poSkewedJoin.getOperatorKey().scope));
    }

    /** Uses the parallelism recorded in the key distribution if it is positive, else the default. */
    private static Integer effectiveParallelism(Integer fromDistribution, Integer defaultParallelism) {
        return fromDistribution.intValue() > 0 ? fromDistribution : defaultParallelism;
    }

    /**
     * Parses the broadcast key distribution into {@code key -> (firstReducer, extraReducers)}.
     *
     * @param keyDist broadcast sampler output; may be null or empty
     * @param totalReducers out-parameter: element 0 receives the TOTAL_REDUCERS value, or -1 when
     *     the distribution is empty
     * @return the reducer map; empty when there is no distribution
     */
    public static Map<Tuple, Pair<Integer, Integer>> loadKeyDistribution(Broadcast<List<Tuple>> keyDist, Integer[] totalReducers) {
        Map<Tuple, Pair<Integer, Integer>> reducerMap = new HashMap<>();
        totalReducers[0] = -1;
        List<Tuple> distribution = (keyDist == null) ? null : keyDist.value();
        if (distribution == null || distribution.isEmpty()) {
            log.warn("Empty dist file: ");
            return reducerMap;
        }
        try {
            TupleFactory tupleFactory = TupleFactory.getInstance();
            Map<?, ?> distMap = (Map<?, ?>) distribution.get(0).get(0);
            DataBag partitionList = (DataBag) distMap.get(PartitionSkewedKeys.PARTITION_LIST);
            totalReducers[0] = Integer.valueOf(String.valueOf(distMap.get(PartitionSkewedKeys.TOTAL_REDUCERS)));
            for (Tuple entry : partitionList) {
                // The last two fields are the min and max reducer indexes; the rest form the key.
                Integer minIndex = (Integer) entry.get(entry.size() - 2);
                Integer maxIndex = (Integer) entry.get(entry.size() - 1);
                if (maxIndex.intValue() < minIndex.intValue()) {
                    // The range wraps around past the last reducer.
                    maxIndex = Integer.valueOf(totalReducers[0].intValue() + maxIndex.intValue());
                }
                Tuple key = tupleFactory.newTuple();
                for (int i = 0; i < entry.size() - 2; i++) {
                    key.append(entry.get(i));
                }
                reducerMap.put(key, new Pair<>(minIndex, Integer.valueOf(maxIndex.intValue() - minIndex.intValue())));
            }
        } catch (ExecException e) {
            log.warn(e.getMessage(), e);
        }
        return reducerMap;
    }

    private SkewedJoinPartitioner buildPartitioner(Broadcast<List<Tuple>> keyDist, Integer defaultParallelism) {
        Integer[] totalReducers = new Integer[1];
        loadKeyDistribution(keyDist, totalReducers);
        return new SkewedJoinPartitioner(effectiveParallelism(totalReducers[0], defaultParallelism).intValue());
    }

    /**
     * Performs the inner, left, right or full outer join selected by the operator's inner flags.
     * It then flattens each joined record with {@link ToValueFunction}.
     */
    private JavaRDD<Tuple> doJoin(JavaPairRDD<PartitionIndexedKey, Tuple> skewedPairs, JavaPairRDD<PartitionIndexedKey, Tuple> streamedPairs, SkewedJoinPartitioner partitioner, Broadcast<List<Tuple>> keyDist) {
        boolean[] innerFlags = this.poSkewedJoin.getInnerFlags();
        int[] schemaSize = new int[2];
        for (int i = 0; i < 2; i++) {
            if (this.poSkewedJoin.getSchema(i) != null) {
                schemaSize[i] = this.poSkewedJoin.getSchema(i).size();
            }
        }
        boolean leftInner = innerFlags[0];
        boolean rightInner = innerFlags[1];
        if (leftInner && rightInner) {
            return skewedPairs.join(streamedPairs, partitioner)
                    .mapPartitions(new ToValueFunction<Tuple, Tuple>(innerFlags, schemaSize, keyDist));
        }
        if (leftInner) {
            return skewedPairs.leftOuterJoin(streamedPairs, partitioner)
                    .mapPartitions(new ToValueFunction<Tuple, Optional<Tuple>>(innerFlags, schemaSize, keyDist));
        }
        if (rightInner) {
            return skewedPairs.rightOuterJoin(streamedPairs, partitioner)
                    .mapPartitions(new ToValueFunction<Optional<Tuple>, Tuple>(innerFlags, schemaSize, keyDist));
        }
        return skewedPairs.fullOuterJoin(streamedPairs, partitioner)
                .mapPartitions(new ToValueFunction<Optional<Tuple>, Optional<Tuple>>(innerFlags, schemaSize, keyDist));
    }
}
