/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.graph.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import org.apache.spark.broadcast.Broadcast;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.spark.impl.graph.scoring.GraphFeedForwardWithKeyFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Feeds keyed feature arrays through a {@link ComputationGraph} and returns the
 * network outputs under the same keys.
 *
 * <p>For each call, a network is built from the broadcast JSON configuration and
 * the broadcast parameters, all tuples of the iterator are collected, merged
 * into batches along dimension 0, passed through the network, and the outputs
 * are split back so that every key receives the output rows that correspond to
 * its own input rows.
 *
 * @param <K> type of the key attached to each set of feature arrays
 */
class GraphFeedForwardWithKeyFunctionAdapter<K>
implements FlatMapFunctionAdapter<Iterator<Tuple2<K, INDArray[]>>, Tuple2<K, INDArray[]>> {
    protected static Logger log = LoggerFactory.getLogger(GraphFeedForwardWithKeyFunction.class);
    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final int batchSize;

    /**
     * @param params     broadcast network parameters
     * @param jsonConfig broadcast network configuration, as JSON
     * @param batchSize  example-count threshold at which a batch stops growing;
     *                   NOTE(review): the batching loop assumes this is positive - confirm with callers
     */
    public GraphFeedForwardWithKeyFunctionAdapter(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.batchSize = batchSize;
    }

    /**
     * Computes the network output for every tuple supplied by the iterator.
     *
     * @param iterator keyed feature arrays (one {@code INDArray} per network input)
     * @return one tuple per input tuple, in input order, pairing the original key
     *         with the slice of each network output belonging to that tuple;
     *         an empty list if the iterator has no elements
     * @throws IllegalStateException if the broadcast parameter count does not
     *                               match the parameter count of the configured network
     */
    @Override
    public Iterable<Tuple2<K, INDArray[]>> call(Iterator<Tuple2<K, INDArray[]>> iterator) throws Exception {
        if (!iterator.hasNext()) {
            return Collections.emptyList();
        }

        ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(this.jsonConfig.getValue()));
        network.init();
        INDArray val = this.params.value().unsafeDuplication();
        if (val.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
        }
        network.setParams(val);

        // Pass 1: collect every tuple and note whether all of them share the same
        // non-minibatch dimensions (only then can batches be formed without checks).
        ArrayList<INDArray[]> featuresList = new ArrayList<INDArray[]>(this.batchSize);
        ArrayList<K> keyList = new ArrayList<K>(this.batchSize);
        ArrayList<Integer> origSizeList = new ArrayList<Integer>();
        int[][] firstShapes = null;
        boolean sizesDiffer = false;
        while (iterator.hasNext()) {
            Tuple2<K, INDArray[]> t2 = iterator.next();
            INDArray[] features = t2._2();
            if (firstShapes == null) {
                firstShapes = GraphFeedForwardWithKeyFunctionAdapter.shapesOf(features);
            } else if (!sizesDiffer) {
                // Compare the incoming tuple itself (not one that was stored earlier),
                // so that a mismatch in the second or the last tuple is detected too.
                sizesDiffer = !GraphFeedForwardWithKeyFunctionAdapter.sameTrailingShape(firstShapes, features);
            }
            featuresList.add(features);
            keyList.add(t2._1());
            origSizeList.add(features[0].size(0));
        }
        if (featuresList.isEmpty()) {
            return Collections.emptyList();
        }

        // Pass 2: merge consecutive tuples into batches, run the network, split the outputs.
        ArrayList<Tuple2<K, INDArray[]>> output = new ArrayList<Tuple2<K, INDArray[]>>(featuresList.size());
        int batchStart = 0;
        while (batchStart < featuresList.size()) {
            int batchEnd = batchStart;
            int examplesInBatch = 0;
            int[][] batchShapes = null;
            ArrayList<INDArray[]> toMerge = new ArrayList<INDArray[]>();
            while (batchEnd < featuresList.size() && examplesInBatch < this.batchSize) {
                INDArray[] f = featuresList.get(batchEnd);
                if (batchShapes == null) {
                    batchShapes = GraphFeedForwardWithKeyFunctionAdapter.shapesOf(f);
                } else if (sizesDiffer && !GraphFeedForwardWithKeyFunctionAdapter.sameTrailingShape(batchShapes, f)) {
                    // Arrays with different trailing dimensions cannot be concatenated
                    // along dimension 0; close this batch and start a new one here.
                    break;
                }
                toMerge.add(f);
                examplesInBatch += f[0].size(0);
                ++batchEnd;
            }

            // Concatenate, per network input, the arrays of all tuples in this batch.
            int numInputs = toMerge.get(0).length;
            INDArray[] batchFeatures = new INDArray[numInputs];
            for (int in = 0; in < numInputs; ++in) {
                INDArray[] perInput = new INDArray[toMerge.size()];
                for (int k = 0; k < perInput.length; ++k) {
                    perInput[k] = toMerge.get(k)[in];
                }
                batchFeatures[in] = Nd4j.concat(0, perInput);
            }

            INDArray[] out = network.output(false, batchFeatures);

            // Hand each tuple back the rows of every output that came from its own input rows.
            int rowOffset = 0;
            for (int idx = batchStart; idx < batchEnd; ++idx) {
                int numExamples = origSizeList.get(idx);
                INDArray[] outSubset = new INDArray[out.length];
                for (int o = 0; o < out.length; ++o) {
                    outSubset[o] = this.getSubset(rowOffset, rowOffset + numExamples, out[o]);
                }
                rowOffset += numExamples;
                output.add(new Tuple2<K, INDArray[]>(keyList.get(idx), outSubset));
            }
            batchStart = batchEnd;
        }

        // NOTE(review): presumably flushes any queued ND4J operations before the results are returned - confirm.
        Nd4j.getExecutioner().commit();
        return output;
    }

    /** Returns the shape of every array, in order. */
    private static int[][] shapesOf(INDArray[] arrays) {
        int[][] shapes = new int[arrays.length][];
        for (int i = 0; i < arrays.length; ++i) {
            shapes[i] = arrays[i].shape();
        }
        return shapes;
    }

    /**
     * Tells whether each array agrees with the matching reference shape in every
     * dimension except dimension 0. A different number of arrays, or a different
     * rank, counts as a mismatch.
     */
    private static boolean sameTrailingShape(int[][] reference, INDArray[] arrays) {
        if (arrays.length != reference.length) {
            return false;
        }
        for (int i = 0; i < reference.length; ++i) {
            if (arrays[i].rank() != reference[i].length) {
                return false;
            }
            for (int j = 1; j < reference[i].length; ++j) {
                if (reference[i][j] != arrays[i].size(j)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Selects the rows {@code [exampleStart, exampleEnd)} along dimension 0 and
     * everything along the remaining dimensions.
     *
     * @throws RuntimeException if {@code from} is not of rank 2, 3 or 4
     */
    private INDArray getSubset(int exampleStart, int exampleEnd, INDArray from) {
        switch (from.rank()) {
            case 2:
                return from.get(new INDArrayIndex[]{NDArrayIndex.interval(exampleStart, exampleEnd), NDArrayIndex.all()});
            case 3:
                return from.get(new INDArrayIndex[]{NDArrayIndex.interval(exampleStart, exampleEnd), NDArrayIndex.all(), NDArrayIndex.all()});
            case 4:
                return from.get(new INDArrayIndex[]{NDArrayIndex.interval(exampleStart, exampleEnd), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()});
            default:
                throw new RuntimeException("Invalid rank: " + from.rank());
        }
    }
}

