/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.randomprojection;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.clustering.randomprojection.RPHyperPlanes;
import org.deeplearning4j.clustering.randomprojection.RPNode;
import org.deeplearning4j.clustering.randomprojection.RPTree;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Doubles;

/**
 * Static helpers for building and querying random-projection trees ({@link RPTree}).
 *
 * <p>Distance computations go through ND4J reduce3 ops cached in a per-thread map
 * (see {@link #getOp}). The cache itself is therefore never shared between threads.
 */
public class RPUtils {
    /** Per-thread cache of distance ops, keyed by similarity-function name. */
    private static final ThreadLocal<Map<String, DifferentialFunction>> functionInstances = new ThreadLocal<>();

    /**
     * Returns a distance op for {@code name}, bound to {@code x}, {@code y} and {@code result}.
     *
     * <p>Recognized names are {@code "cosinedistance"}, {@code "cosinesimilarity"},
     * {@code "manhattan"}, {@code "jaccard"} and {@code "hamming"}. Any other name yields a
     * {@link EuclideanDistance}.
     *
     * <p>If {@code x} and {@code y} differ in length, the op is built with its
     * {@code allDistances} constructor flag set. A cached op is reused only when that flag
     * still matches; otherwise a new op replaces it. A reused op is always rebound to the
     * current arrays.
     *
     * @param name   similarity-function name
     * @param x      first input
     * @param y      second input
     * @param result array the op writes its output to
     * @return the op, ready to be executed
     */
    public static <T extends DifferentialFunction> DifferentialFunction getOp(String name, INDArray x, INDArray y, INDArray result) {
        Map<String, DifferentialFunction> ops = functionInstances.get();
        if (ops == null) {
            ops = new HashMap<>();
            functionInstances.set(ops);
        }
        boolean allDistances = x.length() != y.length();
        DifferentialFunction cached = ops.get(name);
        if (cached == null || ((ReduceOp) cached).isComplexAccumulation() != allDistances) {
            DifferentialFunction created = newDistanceOp(name, x, y, result, allDistances);
            ops.put(name, created);
            return created;
        }
        // Rebind the cached op for every name. Skipping this (as the cosinedistance branch
        // used to) makes the op keep operating on the arrays from its first use.
        ReduceOp op = (ReduceOp) cached;
        op.setX(x);
        op.setY(y);
        op.setZ(result);
        return cached;
    }

    /** Instantiates the reduce3 op for {@code name}; unrecognized names map to Euclidean distance. */
    private static DifferentialFunction newDistanceOp(String name, INDArray x, INDArray y, INDArray result, boolean allDistances) {
        switch (name) {
            case "cosinedistance":
                return new CosineDistance(x, y, result, allDistances, new int[0]);
            case "cosinesimilarity":
                return new CosineSimilarity(x, y, result, allDistances, new int[0]);
            case "manhattan":
                return new ManhattanDistance(x, y, result, allDistances, new int[0]);
            case "jaccard":
                return new JaccardDistance(x, y, result, allDistances, new int[0]);
            case "hamming":
                return new HammingDistance(x, y, result, allDistances, new int[0]);
            default:
                return new EuclideanDistance(x, y, result, allDistances, new int[0]);
        }
    }

    /**
     * Finds up to {@code n} nearest candidates of {@code toQuery} among the rows of {@code X}.
     *
     * <p>Candidates are the leaf indices reached in each tree. They are ranked by ascending
     * distance to {@code toQuery}.
     *
     * @param toQuery            query point
     * @param X                  data set; candidate {@code i} is {@code X.slice(i)}
     * @param trees              trees to search; must not be empty
     * @param n                  maximum number of results; {@code n <= 0} yields an empty list
     * @param similarityFunction distance name understood by {@link #getOp}
     * @return (distance, row index) pairs, closest first
     * @throws ND4JIllegalArgumentException if {@code trees} is empty
     */
    public static List<Pair<Double, Integer>> queryAllWithDistances(INDArray toQuery, INDArray X, List<RPTree> trees, int n, String similarityFunction) {
        if (trees.isEmpty()) {
            throw new ND4JIllegalArgumentException("Trees is empty!");
        }
        List<Integer> candidates = getCandidates(toQuery, trees, similarityFunction);
        List<Pair<Double, Integer>> sortedCandidates = sortCandidates(toQuery, X, candidates, similarityFunction);
        int numReturns = Math.max(0, Math.min(n, sortedCandidates.size()));
        return new ArrayList<>(sortedCandidates.subList(0, numReturns));
    }

    /**
     * Same search as {@link #queryAllWithDistances}, but returns only the row indices.
     *
     * @return a 1-D array of up to {@code n} row indices, closest first
     * @throws ND4JIllegalArgumentException if {@code trees} is empty
     */
    public static INDArray queryAll(INDArray toQuery, INDArray X, List<RPTree> trees, int n, String similarityFunction) {
        List<Pair<Double, Integer>> nearest = queryAllWithDistances(toQuery, X, trees, n, similarityFunction);
        INDArray result = Nd4j.create(nearest.size());
        for (int i = 0; i < nearest.size(); ++i) {
            result.putScalar((long) i, nearest.get(i).getSecond().intValue());
        }
        return result;
    }

    /**
     * Computes the distance from {@code x} to each candidate row of {@code X} and sorts by it.
     *
     * <p>Consecutive repeats of the same candidate index are scored only once.
     *
     * @param x                  query point
     * @param X                  data set; candidate {@code i} is {@code X.slice(i)}
     * @param candidates         row indices to score
     * @param similarityFunction distance name understood by {@link #getOp}
     * @return (distance, row index) pairs in ascending distance order
     */
    public static List<Pair<Double, Integer>> sortCandidates(INDArray x, INDArray X, List<Integer> candidates, String similarityFunction) {
        List<Pair<Double, Integer>> ret = new ArrayList<>(candidates.size());
        int prevIdx = -1;
        for (int candidate : candidates) {
            // Compare against the previous candidate *value*, not its list position.
            if (candidate != prevIdx) {
                double dist = computeDistance(similarityFunction, X.slice((long) candidate), x);
                ret.add(Pair.of(dist, candidate));
            }
            prevIdx = candidate;
        }
        Collections.sort(ret, new Comparator<Pair<Double, Integer>>() {
            @Override
            public int compare(Pair<Double, Integer> a, Pair<Double, Integer> b) {
                return Doubles.compare(a.getFirst(), b.getFirst());
            }
        });
        return ret;
    }

    /**
     * Returns the distinct candidate indices for {@code x} across all trees, in ascending order.
     *
     * @param x                  query point
     * @param trees              trees to search
     * @param similarityFunction distance name understood by {@link #getOp}
     * @return a 1-D array of distinct candidate indices; empty if there are no candidates
     */
    public static INDArray getAllCandidates(INDArray x, List<RPTree> trees, String similarityFunction) {
        List<Integer> candidates = getCandidates(x, trees, similarityFunction);
        Collections.sort(candidates);
        // One (occurrence count, candidate index) pair per run of equal indices.
        List<Pair<Integer, Integer>> scores = new ArrayList<>();
        int prevIdx = -1;
        int idxCount = 0;
        for (int candidate : candidates) {
            if (candidate == prevIdx) {
                ++idxCount;
                continue;
            }
            if (prevIdx != -1) {
                scores.add(Pair.of(idxCount, prevIdx));
            }
            prevIdx = candidate;
            idxCount = 1;
        }
        if (prevIdx != -1) {
            scores.add(Pair.of(idxCount, prevIdx));
        }
        INDArray arr = Nd4j.create(scores.size());
        for (int i = 0; i < scores.size(); ++i) {
            arr.putScalar((long) i, scores.get(i).getSecond().intValue());
        }
        return arr;
    }

    /**
     * Collects the indices stored in the leaf that {@code x} reaches in each tree.
     *
     * @return the union of those indices, without duplicates, in first-seen order
     */
    public static List<Integer> getCandidates(INDArray x, List<RPTree> roots, String similarityFunction) {
        LinkedHashSet<Integer> ret = new LinkedHashSet<>();
        for (RPTree tree : roots) {
            RPNode leaf = query(tree.getRoot(), tree.getRpHyperPlanes(), x, similarityFunction);
            ret.addAll(leaf.getIndices());
        }
        return new ArrayList<>(ret);
    }

    /**
     * Descends from {@code from} to the leaf that {@code x} falls into.
     *
     * <p>At each internal node, the distance from {@code x} to that depth's hyperplane is
     * compared with the node's median. A distance less than or equal to the median goes
     * left; a greater distance goes right.
     *
     * @return the leaf node reached (a node with no children)
     */
    public static RPNode query(RPNode from, RPHyperPlanes planes, INDArray x, String similarityFunction) {
        RPNode node = from;
        while (node.getLeft() != null || node.getRight() != null) {
            INDArray hyperPlane = planes.getHyperPlaneAt(node.getDepth());
            double dist = computeDistance(similarityFunction, x, hyperPlane);
            node = dist <= node.getMedian() ? node.getLeft() : node.getRight();
        }
        return node;
    }

    /**
     * Executes the distance op for {@code function} along dimension 1.
     *
     * @return the op's output array ({@code op.z()})
     */
    public static INDArray computeDistanceMulti(String function, INDArray x, INDArray y, INDArray result) {
        ReduceOp op = (ReduceOp) getOp(function, x, y, result);
        op.setDimensions(new int[]{1});
        Nd4j.getExecutioner().exec(op);
        return op.z();
    }

    /**
     * Executes the distance op for {@code function}, writing into {@code result}.
     *
     * @return element 0 of the op's output array
     */
    public static double computeDistance(String function, INDArray x, INDArray y, INDArray result) {
        ReduceOp op = (ReduceOp) getOp(function, x, y, result);
        Nd4j.getExecutioner().exec(op);
        return op.z().getDouble(0L);
    }

    /** Like {@link #computeDistance(String, INDArray, INDArray, INDArray)}, using a fresh scalar result. */
    public static double computeDistance(String function, INDArray x, INDArray y) {
        return computeDistance(function, x, y, Nd4j.scalar(0.0));
    }

    /**
     * Recursively splits {@code from} until every leaf holds at most {@code maxSize} indices.
     *
     * <p>At each level, points are projected onto the hyperplane for {@code depth}; one is
     * added to {@code planes} if missing. The node's median is set to the element at index
     * {@code size / 2} of the sorted projections. Points at or below the median go left; the
     * rest go right. If either side would be empty, the node stays a leaf.
     *
     * @param tree               tree the new nodes belong to
     * @param from               node to split; its indices are cleared once it has children
     * @param planes             hyperplanes, extended as needed
     * @param X                  data set; index {@code i} refers to {@code X.slice(i)}
     * @param maxSize            largest index count a leaf may hold without being split
     * @param depth              depth of {@code from}
     * @param similarityFunction distance name understood by {@link #getOp}
     */
    public static void buildTree(RPTree tree, RPNode from, RPHyperPlanes planes, INDArray X, int maxSize, int depth, String similarityFunction) {
        List<Integer> indices = from.getIndices();
        if (indices.size() <= maxSize) {
            slimNode(from);
            return;
        }
        RPNode left = new RPNode(tree, depth + 1);
        RPNode right = new RPNode(tree, depth + 1);
        if (planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) {
            planes.addRandomHyperPlane();
        }
        INDArray hyperPlane = planes.getHyperPlaneAt(depth);

        // Project each point once; the same values drive both the median and the split.
        double[] distances = new double[indices.size()];
        for (int i = 0; i < distances.length; ++i) {
            distances[i] = computeDistance(similarityFunction, hyperPlane, X.slice((long) indices.get(i).intValue()));
        }
        double[] sorted = distances.clone();
        Arrays.sort(sorted);
        from.setMedian(sorted[sorted.length / 2]);

        double median = from.getMedian();
        for (int i = 0; i < distances.length; ++i) {
            Integer index = indices.get(i);
            if (distances[i] <= median) {
                left.getIndices().add(index);
            } else {
                right.getIndices().add(index);
            }
        }
        if (left.getIndices().isEmpty() || right.getIndices().isEmpty()) {
            slimNode(from);
            return;
        }
        from.setLeft(left);
        from.setRight(right);
        slimNode(from);
        buildTree(tree, left, planes, X, maxSize, depth + 1, similarityFunction);
        buildTree(tree, right, planes, X, maxSize, depth + 1, similarityFunction);
    }

    /** Appends every leaf of {@code scan} to {@code nodes}. */
    public static void scanForLeaves(List<RPNode> nodes, RPTree scan) {
        scanForLeaves(nodes, scan.getRoot());
    }

    /** Appends every leaf under {@code current} to {@code nodes}, left subtree before right. */
    public static void scanForLeaves(List<RPNode> nodes, RPNode current) {
        if (current.getLeft() == null && current.getRight() == null) {
            nodes.add(current);
        }
        if (current.getLeft() != null) {
            scanForLeaves(nodes, current.getLeft());
        }
        if (current.getRight() != null) {
            scanForLeaves(nodes, current.getRight());
        }
    }

    /** Clears the indices of an internal node (one with both children); leaves are untouched. */
    public static void slimNode(RPNode node) {
        if (node.getRight() != null && node.getLeft() != null) {
            node.getIndices().clear();
        }
    }
}

