/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.vptree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.HeapObject;
import org.deeplearning4j.clustering.util.MathUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.Dot;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class VPTree
implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(VPTree.class);
    private static final long serialVersionUID = 1L;
    public static final String EUCLIDEAN = "euclidean";
    private double tau;
    private INDArray items;
    private List<INDArray> itemsList;
    private Node root;
    private String similarityFunction;
    private boolean invert = false;
    private transient ExecutorService executorService;
    private int workers = 1;
    private AtomicInteger size = new AtomicInteger(0);
    private transient ThreadLocal<INDArray> scalars = new ThreadLocal();
    private WorkspaceConfiguration workspaceConfiguration;

    /**
     * Creates an empty tree: no items and no root are set, only the per-thread
     * scalar cache used by {@link #distance(INDArray, INDArray)} is initialised.
     */
    protected VPTree() {
        scalars = new ThreadLocal<>();
    }

    /**
     * Builds a tree over {@code points} with the {@link #EUCLIDEAN} similarity
     * function and a single worker.
     *
     * @param points matrix whose rows are the points to index
     * @param invert when {@code true}, every computed distance is negated
     */
    public VPTree(INDArray points, boolean invert) {
        this(points, EUCLIDEAN, 1, invert);
    }

    /**
     * Builds a tree over {@code points} with the {@link #EUCLIDEAN} similarity
     * function.
     *
     * @param points  matrix whose rows are the points to index
     * @param invert  when {@code true}, every computed distance is negated
     * @param workers number of build workers; values above 1 build subtrees on a thread pool
     */
    public VPTree(INDArray points, boolean invert, int workers) {
        this(points, EUCLIDEAN, workers, invert);
    }

    /**
     * Builds a single-worker tree over {@code items}.
     *
     * @param items              matrix whose rows are the points to index
     * @param similarityFunction name of the distance measure; names not handled by the
     *                           distance switch fall back to Euclidean distance
     * @param invert             when {@code true}, every computed distance is negated
     */
    public VPTree(INDArray items, String similarityFunction, boolean invert) {
        // Same field values and build as the explicit form: the worker count is 1 throughout.
        this(items, similarityFunction, 1, invert);
    }

    /**
     * Builds a tree from a list of data points. The points are piled into a single
     * array via {@code Nd4j.pile}, which becomes this tree's item matrix.
     *
     * @param items              data points to index
     * @param similarityFunction name of the distance measure; names not handled by the
     *                           distance switch fall back to Euclidean distance
     * @param workers            number of build workers; values above 1 build subtrees on a thread pool
     * @param invert             when {@code true}, every computed distance is negated
     */
    public VPTree(List<DataPoint> items, String similarityFunction, int workers, boolean invert) {
        this.workers = workers;
        this.invert = invert;
        this.similarityFunction = similarityFunction;

        final INDArray[] rows = new INDArray[items.size()];
        int next = 0;
        for (DataPoint dataPoint : items) {
            rows[next++] = dataPoint.getPoint();
        }
        this.items = Nd4j.pile(rows);
        this.root = buildFromPoints(this.items);
    }

    /**
     * Builds a single-worker, non-inverted tree over {@code items}.
     *
     * @param items              matrix whose rows are the points to index
     * @param similarityFunction name of the distance measure to use
     */
    public VPTree(INDArray items, String similarityFunction) {
        this(items, similarityFunction, 1, false);
    }

    /**
     * Builds a tree over {@code items}.
     *
     * @param items              matrix whose rows are the points to index
     * @param similarityFunction name of the distance measure; names not handled by the
     *                           distance switch fall back to Euclidean distance
     * @param workers            number of build workers; values above 1 build subtrees on a thread pool
     * @param invert             when {@code true}, every computed distance is negated
     */
    public VPTree(INDArray items, String similarityFunction, int workers, boolean invert) {
        // All configuration must be in place before the build, which reads these fields.
        this.items = items;
        this.workers = workers;
        this.similarityFunction = similarityFunction;
        this.invert = invert;
        this.root = buildFromPoints(items);
    }

    /**
     * Builds a single-worker, non-inverted tree from a list of data points.
     *
     * @param items              data points to index
     * @param similarityFunction name of the distance measure to use
     */
    public VPTree(List<DataPoint> items, String similarityFunction) {
        this(items, similarityFunction, 1, false);
    }

    /**
     * Builds a tree over {@code items} using the {@link #EUCLIDEAN} similarity function.
     *
     * @param items matrix whose rows are the points to index
     */
    public VPTree(INDArray items) {
        this(items, EUCLIDEAN);
    }

    /**
     * Builds a tree from a list of data points using the {@link #EUCLIDEAN}
     * similarity function.
     *
     * @param items data points to index
     */
    public VPTree(List<DataPoint> items) {
        this(items, EUCLIDEAN);
    }

    /**
     * Copies the points of {@code data} into a newly created array, one slice per
     * data point. The array shape is {@code [data.size(), data.get(0).getD()]}, so
     * the list must contain at least one element.
     *
     * @param data data points to copy
     * @return the array holding one slice per data point
     */
    public static INDArray buildFromData(List<DataPoint> data) {
        final long[] shape = {data.size(), data.get(0).getD()};
        final INDArray matrix = Nd4j.create(shape);
        for (int row = 0; row < matrix.slices(); row++) {
            matrix.putSlice(row, data.get(row).getPoint());
        }
        return matrix;
    }

    /**
     * Computes the distance from {@code basePoint} to {@code items} with the
     * configured similarity function and writes the result into
     * {@code distancesArr}. Unrecognised function names use Euclidean distance.
     * When this tree is inverted the result is negated in place.
     *
     * @param items        points to measure
     * @param basePoint    reference point
     * @param distancesArr output array that receives the distances
     */
    public void calcDistancesRelativeTo(INDArray items, INDArray basePoint, INDArray distancesArr) {
        final int[] dimension = new int[]{-1};
        final ReduceOp op;
        switch (this.similarityFunction) {
            case "cosinedistance":
                op = new CosineDistance(items, basePoint, distancesArr, true, dimension);
                break;
            case "cosinesimilarity":
                op = new CosineSimilarity(items, basePoint, distancesArr, true, dimension);
                break;
            case "manhattan":
                op = new ManhattanDistance(items, basePoint, distancesArr, true, dimension);
                break;
            case "dot":
                op = new Dot(items, basePoint, distancesArr, dimension);
                break;
            case "jaccard":
                op = new JaccardDistance(items, basePoint, distancesArr, true, dimension);
                break;
            case "hamming":
                op = new HammingDistance(items, basePoint, distancesArr, true, dimension);
                break;
            case "euclidean":
            default:
                // Euclidean is both an explicit choice and the fallback.
                op = new EuclideanDistance(items, basePoint, distancesArr, true, dimension);
                break;
        }
        Nd4j.getExecutioner().exec(op);
        if (this.invert) {
            distancesArr.negi();
        }
    }

    /**
     * Computes the distance from {@code basePoint} to this tree's own items and
     * writes the result into {@code distancesArr}.
     *
     * @param basePoint    reference point
     * @param distancesArr output array that receives the distances
     */
    public void calcDistancesRelativeTo(INDArray basePoint, INDArray distancesArr) {
        this.calcDistancesRelativeTo(this.items, basePoint, distancesArr);
    }

    /**
     * Computes the distance between two arrays with the configured similarity
     * function. Unrecognised function names use Euclidean distance. The reduce
     * ops write into a scalar cached per thread; it is created on a thread's
     * first call using the data type of {@code arr1}.
     *
     * @param arr1 first array
     * @param arr2 second array
     * @return the distance, negated when this tree is inverted
     */
    public double distance(INDArray arr1, INDArray arr2) {
        // The cache is transient, so it may be missing on a deserialized instance.
        if (this.scalars == null) {
            this.scalars = new ThreadLocal<>();
        }
        if (this.scalars.get() == null) {
            this.scalars.set(Nd4j.scalar((DataType)arr1.dataType(), (Number)0.0));
        }
        final INDArray out = this.scalars.get();

        final ReduceOp op;
        switch (this.similarityFunction) {
            case "dot": {
                // Dot product goes through BLAS rather than a reduce op.
                final double dot = Nd4j.getBlasWrapper().dot(arr1, arr2);
                return this.invert ? -dot : dot;
            }
            case "jaccard":
                op = new JaccardDistance(arr1, arr2, out);
                break;
            case "hamming":
                op = new HammingDistance(arr1, arr2, out);
                break;
            case "cosinesimilarity":
                op = new CosineSimilarity(arr1, arr2, out);
                break;
            case "cosinedistance":
                op = new CosineDistance(arr1, arr2, out);
                break;
            case "manhattan":
                op = new ManhattanDistance(arr1, arr2, out);
                break;
            case "euclidean":
            default:
                op = new EuclideanDistance(arr1, arr2, out);
                break;
        }
        final double value = Nd4j.getExecutioner().execAndReturn(op).getFinalResult().doubleValue();
        return this.invert ? -value : value;
    }

    /**
     * Recursively builds the subtree for the given points.
     *
     * <p>A single point becomes a leaf. Otherwise a randomly chosen point becomes
     * this node's vantage point, the median of the distances from it to all points
     * becomes the threshold, and the remaining points are split: distance strictly
     * below the median goes left, everything else goes right. With more than one
     * worker the children are submitted to the executor and stored as futures;
     * otherwise they are built by direct recursion.
     *
     * @param points  points of this subtree
     * @param indices index of each point, parallel to {@code points}
     * @return the root node of the subtree
     */
    private Node buildFromPoints(List<INDArray> points, List<Integer> indices) {
        final Node node = new Node(0, 0.0f);
        if (points.size() == 1) {
            node.setPoint(points.get(0));
            node.setIndex(indices.get(0));
            return node;
        }

        final INDArray stacked = Nd4j.vstack(points);
        final int vantage = MathUtils.randomNumberBetween(0.0, (double)(stacked.rows() - 1), Nd4j.getRandom());
        final INDArray vantagePoint = points.get(vantage);
        node.setPoint(vantagePoint);
        node.setIndex(indices.get(vantage));

        final INDArray dists = Nd4j.create(new int[]{stacked.rows(), 1});
        calcDistancesRelativeTo(stacked, vantagePoint, dists);
        final double median = dists.medianNumber().doubleValue();
        node.setThreshold((float)median);

        final ArrayList<INDArray> nearPoints = new ArrayList<INDArray>();
        final ArrayList<Integer> nearIndices = new ArrayList<Integer>();
        final ArrayList<INDArray> farPoints = new ArrayList<INDArray>();
        final ArrayList<Integer> farIndices = new ArrayList<Integer>();
        for (int i = 0; (long)i < dists.length(); i++) {
            if (i == vantage) {
                // The vantage point lives in this node, not in a child.
                continue;
            }
            final boolean near = dists.getDouble((long)i) < median;
            (near ? nearPoints : farPoints).add(points.get(i));
            (near ? nearIndices : farIndices).add(indices.get(i));
        }

        final boolean parallel = this.workers > 1;
        if (!nearPoints.isEmpty()) {
            if (parallel) {
                node.setFutureLeft(this.executorService.submit(new NodeBuilder(nearPoints, nearIndices)));
            } else {
                node.setLeft(buildFromPoints(nearPoints, nearIndices));
            }
        }
        if (!farPoints.isEmpty()) {
            if (parallel) {
                node.setFutureRight(this.executorService.submit(new NodeBuilder(farPoints, farIndices)));
            } else {
                node.setRight(buildFromPoints(farPoints, farIndices));
            }
        }
        return node;
    }

    /**
     * Builds the whole tree from a matrix of points and returns its root.
     *
     * <p>When more than one worker is configured and no executor exists yet, a
     * fixed thread pool of daemon threads named "VPTree thread" is created. Each
     * worker first calls {@code unsafeSetDevice} with the device id of the thread
     * that created the pool. The root split is done here; the children are built
     * by the list-based overload, and any futures it left behind are resolved
     * before returning.
     *
     * <p>The executor is shut down in a {@code finally} block so that a failed
     * build does not leave the pool's threads alive.
     *
     * @param items matrix whose rows are the points to index
     * @return the root node of the tree
     */
    private Node buildFromPoints(INDArray items) {
        if (this.executorService == null && items == this.items && this.workers > 1) {
            final Integer deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            this.executorService = Executors.newFixedThreadPool(this.workers, new ThreadFactory(){

                @Override
                public Thread newThread(final Runnable r) {
                    Thread t = new Thread(new Runnable(){

                        @Override
                        public void run() {
                            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
                            r.run();
                        }
                    });
                    t.setDaemon(true);
                    t.setName("VPTree thread");
                    return t;
                }
            });
        }
        try {
            Node ret = new Node(0, 0.0f);
            this.size.incrementAndGet();
            int randomPoint = MathUtils.randomNumberBetween(0.0, (double)(items.rows() - 1), Nd4j.getRandom());
            INDArray basePoint = items.getRow((long)randomPoint, true);
            INDArray distancesArr = Nd4j.create((int[])new int[]{items.rows(), 1});
            ret.point = basePoint;
            ret.index = randomPoint;
            this.calcDistancesRelativeTo(items, basePoint, distancesArr);
            double medianDistance = distancesArr.medianNumber().doubleValue();
            ret.threshold = (float)medianDistance;
            ArrayList<INDArray> leftPoints = new ArrayList<INDArray>();
            ArrayList<Integer> leftIndices = new ArrayList<Integer>();
            ArrayList<INDArray> rightPoints = new ArrayList<INDArray>();
            ArrayList<Integer> rightIndices = new ArrayList<Integer>();
            for (int i = 0; (long)i < distancesArr.length(); ++i) {
                if (i == randomPoint) {
                    // The vantage point is stored in the root itself.
                    continue;
                }
                if (distancesArr.getDouble((long)i) < medianDistance) {
                    leftPoints.add(items.getRow((long)i, true));
                    leftIndices.add(i);
                } else {
                    rightPoints.add(items.getRow((long)i, true));
                    rightIndices.add(i);
                }
            }
            if (!leftPoints.isEmpty()) {
                ret.left = this.buildFromPoints(leftPoints, leftIndices);
            }
            if (!rightPoints.isEmpty()) {
                ret.right = this.buildFromPoints(rightPoints, rightIndices);
            }
            // Wait for subtrees that were handed to the executor.
            if (ret.left != null) {
                ret.left.fetchFutures();
            }
            if (ret.right != null) {
                ret.right.fetchFutures();
            }
            return ret;
        } finally {
            // Release the pool on failure as well as on success.
            if (this.executorService != null) {
                this.executorService.shutdown();
            }
        }
    }

    /**
     * Finds the nearest neighbours of {@code target}, filtering out an exact
     * (zero-distance) first match.
     *
     * @param target    query point; must not be {@code null}
     * @param k         number of neighbours requested
     * @param results   output list, cleared and filled with the neighbours found
     * @param distances output list, cleared and filled with the matching distances
     * @throws NullPointerException if {@code target} is {@code null}
     */
    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances) {
        if (target == null) {
            throw new NullPointerException("target is marked non-null but is null");
        }
        this.search(target, k, results, distances, true);
    }

    /**
     * Finds the nearest neighbours of {@code target} without forcing the result
     * to be trimmed ({@code dropEdge} is {@code false}).
     *
     * @param target      query point; must not be {@code null}
     * @param k           number of neighbours requested
     * @param results     output list, cleared and filled with the neighbours found
     * @param distances   output list, cleared and filled with the matching distances
     * @param filterEqual whether a zero-distance first match may be dropped
     * @throws NullPointerException if {@code target} is {@code null}
     */
    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances, boolean filterEqual) {
        if (target == null) {
            throw new NullPointerException("target is marked non-null but is null");
        }
        this.search(target, k, results, distances, filterEqual, false);
    }

    /**
     * Finds the nearest neighbours of {@code target}.
     *
     * <p>The tree is searched for {@code k + 1} candidates ({@code k + 2} when
     * {@code filterEqual} is set). Candidates are returned in ascending distance
     * order. When {@code dropEdge} is set, or more than {@code k} candidates were
     * found, a zero-distance first candidate is removed if {@code filterEqual} is
     * set and the lists are then cut to at most {@code k} entries.
     *
     * <p>If the tree has no item matrix (for example when assembled through the
     * builder with only a root), the shape check and the clamping of {@code k} to
     * the item count are skipped instead of failing with a
     * {@code NullPointerException}. An empty result no longer triggers an
     * {@code IndexOutOfBoundsException} in the {@code filterEqual} check.
     *
     * @param target      query point; must not be {@code null}
     * @param k           number of neighbours requested
     * @param results     output list, cleared and filled with the neighbours found
     * @param distances   output list, cleared and filled with the matching distances
     * @param filterEqual whether a zero-distance first match may be dropped
     * @param dropEdge    whether filtering and trimming are applied even when no more
     *                    than {@code k} candidates were found
     * @throws NullPointerException      if {@code target} is {@code null}
     * @throws ND4JIllegalStateException if the tree has items and {@code target} is not
     *                                   a single vector with as many columns as the items
     */
    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances, boolean filterEqual, boolean dropEdge) {
        if (target == null) {
            throw new NullPointerException("target is marked non-null but is null");
        }
        final boolean hasItems = this.items != null;
        if (hasItems && (!target.isVectorOrScalar() || target.columns() != this.items.columns() || target.rows() > 1)) {
            throw new ND4JIllegalStateException("Target for search should have shape of [1, " + this.items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead");
        }

        // PriorityQueue rejects a capacity below 1; 11 is its own default.
        int capacity = 11;
        if (hasItems) {
            k = Math.min(k, this.items.rows());
            capacity = Math.max(1, this.items.rows());
        }
        results.clear();
        distances.clear();

        PriorityQueue<HeapObject> pq = new PriorityQueue<HeapObject>(capacity, new HeapObjectComparator());
        this.search(this.root, target, k + (filterEqual ? 2 : 1), pq, Double.MAX_VALUE);

        // The heap yields the farthest candidate first; reverse afterwards.
        HeapObject ho;
        while ((ho = pq.poll()) != null) {
            results.add(new DataPoint(ho.getIndex(), ho.getPoint()));
            distances.add(ho.getDistance());
        }
        Collections.reverse(results);
        Collections.reverse(distances);

        if (dropEdge || results.size() > k) {
            if (filterEqual && !distances.isEmpty() && distances.get(0) == 0.0) {
                results.remove(0);
                distances.remove(0);
            }
            while (results.size() > k) {
                results.remove(results.size() - 1);
                distances.remove(distances.size() - 1);
            }
        }
    }

    /**
     * Recursively searches the subtree rooted at {@code node}, collecting up to
     * {@code k} nearest candidates in {@code pq}. The queue must be ordered so
     * that its head is the farthest candidate.
     *
     * <p>The pruning radius is passed by value, so the value a caller hands in
     * can be out of date once a sibling subtree has improved the heap. With a
     * stale radius a node farther than the current worst candidate would still
     * pass the {@code distance < tau} test and evict a better candidate. The
     * radius is therefore re-read from the heap on entry and again between the
     * two child visits.
     *
     * @param node   subtree root; {@code null} is ignored
     * @param target query point
     * @param k      maximum number of candidates kept in {@code pq}
     * @param pq     heap of the best candidates so far, farthest at the head
     * @param cTau   pruning radius known to the caller
     */
    public void search(Node node, INDArray target, int k, PriorityQueue<HeapObject> pq, double cTau) {
        if (node == null) {
            return;
        }
        double tau = VPTree.tightenTau(pq, k, cTau);
        double distance = this.distance(node.getPoint(), target);
        if (distance < tau) {
            if (pq.size() == k) {
                pq.poll();
            }
            pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance));
            if (pq.size() == k) {
                tau = pq.peek().getDistance();
            }
        }
        Node left = node.getLeft();
        Node right = node.getRight();
        if (left == null && right == null) {
            return;
        }
        final double threshold = (double)node.getThreshold();
        if (distance < threshold) {
            // Target is inside the ball: look there first.
            if (distance - tau < threshold) {
                this.search(left, target, k, pq, tau);
            }
            tau = VPTree.tightenTau(pq, k, tau);
            if (distance + tau >= threshold) {
                this.search(right, target, k, pq, tau);
            }
        } else {
            // Target is outside the ball: look there first.
            if (distance + tau >= threshold) {
                this.search(right, target, k, pq, tau);
            }
            tau = VPTree.tightenTau(pq, k, tau);
            if (distance - tau < threshold) {
                this.search(left, target, k, pq, tau);
            }
        }
    }

    /**
     * Returns the smaller of {@code tau} and the distance of the farthest
     * candidate in {@code pq} when the heap already holds {@code k} candidates;
     * otherwise returns {@code tau} unchanged.
     */
    private static double tightenTau(PriorityQueue<HeapObject> pq, int k, double tau) {
        if (k > 0 && pq.size() >= k) {
            return Math.min(tau, pq.peek().getDistance());
        }
        return tau;
    }

    /**
     * Creates a new, empty {@link VPTreeBuilder}.
     *
     * @return a fresh builder
     */
    public static VPTreeBuilder builder() {
        return new VPTreeBuilder();
    }

    /**
     * Assigns every field directly from the arguments. No tree is built and no
     * argument is validated; this is the constructor {@link VPTreeBuilder#build()} calls.
     */
    public VPTree(double tau, INDArray items, List<INDArray> itemsList, Node root, String similarityFunction, boolean invert, ExecutorService executorService, int workers, AtomicInteger size, ThreadLocal<INDArray> scalars, WorkspaceConfiguration workspaceConfiguration) {
        // Data held by the tree.
        this.items = items;
        this.itemsList = itemsList;
        this.root = root;
        this.size = size;

        // Distance configuration.
        this.similarityFunction = similarityFunction;
        this.invert = invert;
        this.tau = tau;
        this.scalars = scalars;

        // Execution environment.
        this.workers = workers;
        this.executorService = executorService;
        this.workspaceConfiguration = workspaceConfiguration;
    }

    /**
     * @return the item matrix held by this tree
     */
    public INDArray getItems() {
        return items;
    }

    /**
     * Replaces the item matrix. Only the field is assigned; the tree nodes are
     * not rebuilt.
     *
     * @param items the new item matrix
     */
    public void setItems(INDArray items) {
        this.items = items;
    }

    /**
     * @return {@code true} if computed distances are negated
     */
    public boolean isInvert() {
        return invert;
    }

    /**
     * @return the configured number of build workers
     */
    public int getWorkers() {
        return workers;
    }

    /**
     * Fluent builder that collects values for every {@link VPTree} field and
     * passes them to the all-arguments constructor. Values that are never set
     * keep their Java defaults ({@code null}, {@code 0}, {@code false}).
     */
    public static class VPTreeBuilder {
        private double tau;
        private INDArray items;
        private List<INDArray> itemsList;
        private Node root;
        private String similarityFunction;
        private boolean invert;
        private ExecutorService executorService;
        private int workers;
        private AtomicInteger size;
        private ThreadLocal<INDArray> scalars;
        private WorkspaceConfiguration workspaceConfiguration;

        VPTreeBuilder() {
        }

        /** Sets the {@code tau} value and returns this builder. */
        public VPTreeBuilder tau(double tau) {
            this.tau = tau;
            return this;
        }

        /** Sets the item matrix and returns this builder. */
        public VPTreeBuilder items(INDArray items) {
            this.items = items;
            return this;
        }

        /** Sets the item list and returns this builder. */
        public VPTreeBuilder itemsList(List<INDArray> itemsList) {
            this.itemsList = itemsList;
            return this;
        }

        /** Sets the root node and returns this builder. */
        public VPTreeBuilder root(Node root) {
            this.root = root;
            return this;
        }

        /** Sets the similarity function name and returns this builder. */
        public VPTreeBuilder similarityFunction(String similarityFunction) {
            this.similarityFunction = similarityFunction;
            return this;
        }

        /** Sets the invert flag and returns this builder. */
        public VPTreeBuilder invert(boolean invert) {
            this.invert = invert;
            return this;
        }

        /** Sets the executor service and returns this builder. */
        public VPTreeBuilder executorService(ExecutorService executorService) {
            this.executorService = executorService;
            return this;
        }

        /** Sets the worker count and returns this builder. */
        public VPTreeBuilder workers(int workers) {
            this.workers = workers;
            return this;
        }

        /** Sets the size counter and returns this builder. */
        public VPTreeBuilder size(AtomicInteger size) {
            this.size = size;
            return this;
        }

        /** Sets the per-thread scalar cache and returns this builder. */
        public VPTreeBuilder scalars(ThreadLocal<INDArray> scalars) {
            this.scalars = scalars;
            return this;
        }

        /** Sets the workspace configuration and returns this builder. */
        public VPTreeBuilder workspaceConfiguration(WorkspaceConfiguration workspaceConfiguration) {
            this.workspaceConfiguration = workspaceConfiguration;
            return this;
        }

        /**
         * Creates the tree from the collected values.
         *
         * @return a new {@link VPTree}
         */
        public VPTree build() {
            return new VPTree(
                    tau,
                    items,
                    itemsList,
                    root,
                    similarityFunction,
                    invert,
                    executorService,
                    workers,
                    size,
                    scalars,
                    workspaceConfiguration);
        }

        @Override
        public String toString() {
            final StringBuilder text = new StringBuilder("VPTree.VPTreeBuilder(tau=");
            text.append(tau);
            text.append(", items=").append(items);
            text.append(", itemsList=").append(itemsList);
            text.append(", root=").append(root);
            text.append(", similarityFunction=").append(similarityFunction);
            text.append(", invert=").append(invert);
            text.append(", executorService=").append(executorService);
            text.append(", workers=").append(workers);
            text.append(", size=").append(size);
            text.append(", scalars=").append(scalars);
            text.append(", workspaceConfiguration=").append(workspaceConfiguration);
            return text.append(")").toString();
        }
    }

    /**
     * One node of the tree: a point with its index, a distance threshold and up
     * to two children. While a tree is being built with several workers, the
     * children may temporarily exist only as futures; {@link #fetchFutures()}
     * turns them into real child nodes. The futures are transient and are not
     * part of {@code equals} or {@code hashCode}.
     */
    public static class Node
    implements Serializable {
        private static final long serialVersionUID = 2L;
        private int index;
        private float threshold;
        private Node left;
        private Node right;
        private INDArray point;
        protected transient Future<Node> futureLeft;
        protected transient Future<Node> futureRight;

        /**
         * @param index     index of this node's point
         * @param threshold distance threshold separating the left and right children
         */
        public Node(int index, float threshold) {
            this.index = index;
            this.threshold = threshold;
        }

        /**
         * Waits for any pending child futures, stores their results as the left
         * and right children, then does the same for both children.
         *
         * @throws RuntimeException wrapping any exception raised while waiting; if the
         *                          wait was interrupted, the thread's interrupt status
         *                          is restored before the exception is thrown
         */
        public void fetchFutures() {
            try {
                if (this.futureLeft != null) {
                    this.left = this.futureLeft.get();
                }
                if (this.futureRight != null) {
                    this.right = this.futureRight.get();
                }
                if (this.left != null) {
                    this.left.fetchFutures();
                }
                if (this.right != null) {
                    this.right.fetchFutures();
                }
            }
            catch (InterruptedException e) {
                // Keep the interrupt visible to callers further up the stack.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        public int getIndex() {
            return this.index;
        }

        public float getThreshold() {
            return this.threshold;
        }

        public Node getLeft() {
            return this.left;
        }

        public Node getRight() {
            return this.right;
        }

        public INDArray getPoint() {
            return this.point;
        }

        public Future<Node> getFutureLeft() {
            return this.futureLeft;
        }

        public Future<Node> getFutureRight() {
            return this.futureRight;
        }

        public void setIndex(int index) {
            this.index = index;
        }

        public void setThreshold(float threshold) {
            this.threshold = threshold;
        }

        public void setLeft(Node left) {
            this.left = left;
        }

        public void setRight(Node right) {
            this.right = right;
        }

        public void setPoint(INDArray point) {
            this.point = point;
        }

        public void setFutureLeft(Future<Node> futureLeft) {
            this.futureLeft = futureLeft;
        }

        public void setFutureRight(Future<Node> futureRight) {
            this.futureRight = futureRight;
        }

        /**
         * Two nodes are equal when index, threshold, both children and the point
         * are equal; the children are compared recursively.
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Node)) {
                return false;
            }
            Node other = (Node)o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (this.getIndex() != other.getIndex()) {
                return false;
            }
            if (Float.compare(this.getThreshold(), other.getThreshold()) != 0) {
                return false;
            }
            if (!Node.nullSafeEquals(this.getLeft(), other.getLeft())) {
                return false;
            }
            if (!Node.nullSafeEquals(this.getRight(), other.getRight())) {
                return false;
            }
            return Node.nullSafeEquals(this.getPoint(), other.getPoint());
        }

        /** Returns whether {@code a} and {@code b} are both null or {@code a.equals(b)}. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        protected boolean canEqual(Object other) {
            return other instanceof Node;
        }

        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            result = result * PRIME + this.getIndex();
            result = result * PRIME + Float.floatToIntBits(this.getThreshold());
            Node $left = this.getLeft();
            result = result * PRIME + ($left == null ? 43 : $left.hashCode());
            Node $right = this.getRight();
            result = result * PRIME + ($right == null ? 43 : $right.hashCode());
            INDArray $point = this.getPoint();
            result = result * PRIME + ($point == null ? 43 : $point.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "VPTree.Node(index=" + this.getIndex() + ", threshold=" + this.getThreshold() + ", left=" + this.getLeft() + ", right=" + this.getRight() + ", point=" + this.getPoint() + ", futureLeft=" + this.getFutureLeft() + ", futureRight=" + this.getFutureRight() + ")";
        }
    }

    /**
     * Orders heap objects by descending distance, so a {@link PriorityQueue}
     * using this comparator keeps the farthest candidate at its head.
     */
    protected class HeapObjectComparator
    implements Comparator<HeapObject> {
        protected HeapObjectComparator() {
        }

        @Override
        public int compare(HeapObject o1, HeapObject o2) {
            final double second = o2.getDistance();
            final double first = o1.getDistance();
            // Arguments are swapped on purpose: larger distances sort first.
            return Double.compare(second, first);
        }
    }

    /**
     * Task that builds the subtree for a list of points by calling the
     * enclosing tree's list-based {@code buildFromPoints}.
     */
    protected class NodeBuilder
    implements Callable<Node> {
        protected List<INDArray> list;
        protected List<Integer> indices;

        /**
         * @param list    points of the subtree to build
         * @param indices index of each point, parallel to {@code list}
         */
        public NodeBuilder(List<INDArray> list, List<Integer> indices) {
            this.indices = indices;
            this.list = list;
        }

        @Override
        public Node call() throws Exception {
            return buildFromPoints(list, indices);
        }
    }
}

