/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.vptree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.PriorityQueue;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.HeapItem;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Vantage-point tree over a list of {@link DataPoint}s for k-nearest-neighbour search.
 *
 * <p>Construction reorders {@code items} in place so that every node owns a contiguous index
 * range {@code [lower, upper)} whose first element is that node's vantage point: the range
 * {@code [lower + 1, median)} forms the left ("inside") subtree and {@code [median, upper)} the
 * right ("outside") subtree. Pairwise distances are memoised in a {@link CounterMap}, which is
 * written to during both construction and search.
 *
 * <p>Not thread-safe: a search keeps its pruning radius in the {@code tau} field and adds
 * entries to the shared distance cache.
 */
public class VPTree {
    public static final String EUCLIDEAN = "euclidean";
    /** Indexed points; reordered in place by tree construction. */
    private List<DataPoint> items;
    /** Pruning radius of the search in progress (distance of the current k-th best candidate). */
    private double tau;
    private Node root;
    /** Memoised distances keyed by (d1, d2); a stored value of 0.0 is treated as "not cached". */
    private CounterMap<DataPoint, DataPoint> distances;
    private String similarityFunction;
    private boolean invert = true;

    /**
     * Builds a tree over the slices of {@code items}, precomputing all pairwise distances.
     *
     * @param items              source matrix; slice {@code i} becomes the point with index {@code i}
     * @param similarityFunction passed to every created {@link DataPoint}
     * @param invert             passed to every created {@link DataPoint}
     */
    public VPTree(INDArray items, String similarityFunction, boolean invert) {
        this(toDataPoints(items, similarityFunction, invert), similarityFunction, invert);
    }

    /**
     * Builds a tree over {@code items} using an existing distance cache.
     *
     * @param items              points to index; this list is reordered in place
     * @param distances          distance cache to read from and extend
     * @param similarityFunction recorded for this tree
     * @param invert             recorded for this tree
     */
    public VPTree(List<DataPoint> items, CounterMap<DataPoint, DataPoint> distances, String similarityFunction, boolean invert) {
        this.items = items;
        this.distances = distances;
        this.invert = invert;
        this.similarityFunction = similarityFunction;
        this.root = this.buildFromPoints(0, items.size());
    }

    /**
     * Builds a tree over {@code items}, precomputing all pairwise distances.
     *
     * @param items              points to index; this list is reordered in place
     * @param similarityFunction recorded for this tree
     * @param invert             recorded for this tree
     */
    public VPTree(List<DataPoint> items, String similarityFunction, boolean invert) {
        this(items, pairwiseDistances(items), similarityFunction, invert);
    }

    public VPTree(INDArray items, String similarityFunction) {
        this(items, similarityFunction, true);
    }

    public VPTree(List<DataPoint> items, CounterMap<DataPoint, DataPoint> distances, String similarityFunction) {
        this(items, distances, similarityFunction, true);
    }

    public VPTree(List<DataPoint> items, String similarityFunction) {
        this(items, similarityFunction, true);
    }

    public VPTree(INDArray items) {
        this(items, EUCLIDEAN);
    }

    public VPTree(List<DataPoint> items, CounterMap<DataPoint, DataPoint> distances) {
        this(items, distances, EUCLIDEAN);
    }

    public VPTree(List<DataPoint> items) {
        this(items, EUCLIDEAN);
    }

    /** Wraps each slice of {@code matrix} in a {@link DataPoint} whose index is the slice number. */
    private static List<DataPoint> toDataPoints(INDArray matrix, String similarityFunction, boolean invert) {
        List<DataPoint> points = new ArrayList<DataPoint>();
        for (int i = 0; i < matrix.slices(); ++i) {
            points.add(new DataPoint(i, matrix.slice(i), similarityFunction, invert));
        }
        return points;
    }

    /** Computes the full pairwise distance table for {@code points} using {@link DataPoint#distance}. */
    private static CounterMap<DataPoint, DataPoint> pairwiseDistances(List<DataPoint> points) {
        return CounterMap.runPairWise(points, new CounterMap.CountFunction<DataPoint>() {

            @Override
            public double count(DataPoint v1, DataPoint v2) {
                return v1.distance(v2);
            }
        });
    }

    /**
     * Stacks the points of {@code data} into a new matrix, one slice per point, in list order.
     *
     * @param data points to stack; must be non-empty (the first point fixes the column count)
     * @return a matrix with {@code data.size()} rows and {@code data.get(0).getD()} columns
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public static INDArray buildFromData(List<DataPoint> data) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Cannot build a matrix from an empty list of data points");
        }
        INDArray ret = Nd4j.create(data.size(), (int) data.get(0).getD());
        for (int i = 0; i < ret.slices(); ++i) {
            ret.putSlice(i, data.get(i).getPoint());
        }
        return ret;
    }

    /** Returns the live (reordered) item list backing this tree. */
    public List<DataPoint> getItems() {
        return this.items;
    }

    /**
     * Replaces the item list. Note: the tree is NOT rebuilt, so node indices keep referring to
     * positions in the new list.
     */
    public void setItems(List<DataPoint> items) {
        this.items = items;
    }

    /**
     * Returns the distance between {@code d1} and {@code d2}, computing and caching it in both
     * directions when the cache holds 0.0 (which also means genuinely-zero distances are
     * recomputed on every call).
     */
    private double getDistance(DataPoint d1, DataPoint d2) {
        double count = this.distances.getCount(d1, d2);
        if (count == 0.0) {
            double realDistance = d1.distance(d2);
            this.distances.setCount(d1, d2, realDistance);
            this.distances.setCount(d2, d1, realDistance);
            return realDistance;
        }
        return count;
    }

    /**
     * Recursively builds the subtree for {@code items[lower, upper)}, reordering that range only.
     *
     * <p>A random point of the range is swapped to {@code lower} and becomes the node's vantage
     * point. The remaining points are ordered by distance to it, so that
     * {@code items[lower + 1, median)} are no farther than {@code items[median]} and
     * {@code items[median, upper)} are no closer. The threshold is the distance from the vantage
     * point to {@code items[median]}.
     *
     * @return the subtree root, or {@code null} for an empty range
     */
    private Node buildFromPoints(int lower, int upper) {
        if (upper == lower) {
            return null;
        }
        Node ret = new Node(lower, 0.0);
        if (upper - lower > 1) {
            int randomPoint = MathUtils.randomNumberBetween(lower, upper - 1);
            Collections.swap(this.items, lower, randomPoint);
            final DataPoint vantagePoint = this.items.get(lower);

            int median = (lower + upper) >>> 1;
            // Sort only this node's range (excluding the vantage point itself); distances are
            // memoised, so the comparator stays consistent across calls.
            Collections.sort(this.items.subList(lower + 1, upper), new Comparator<DataPoint>() {

                @Override
                public int compare(DataPoint a, DataPoint b) {
                    return Double.compare(getDistance(vantagePoint, a), getDistance(vantagePoint, b));
                }
            });

            ret.setThreshold(this.getDistance(vantagePoint, this.items.get(median)));
            ret.setLeft(this.buildFromPoints(lower + 1, median));
            ret.setRight(this.buildFromPoints(median, upper));
        }
        return ret;
    }

    /**
     * Finds up to {@code k} nearest items to {@code target}.
     *
     * <p>{@code results} and {@code distances} are cleared and then filled in parallel. The queue
     * is drained and the lists reversed; since {@code pq.peek()} is treated as the farthest
     * candidate during the search, the lists end up nearest-first. A non-positive {@code k}
     * yields empty lists.
     *
     * @param target    query point
     * @param k         maximum number of neighbours to return
     * @param results   output list of neighbours
     * @param distances output list of the corresponding distances
     */
    public void search(DataPoint target, int k, List<DataPoint> results, List<Double> distances) {
        results.clear();
        distances.clear();
        if (k <= 0) {
            return;
        }
        PriorityQueue<HeapItem> pq = new PriorityQueue<HeapItem>();
        this.tau = Double.MAX_VALUE;
        this.search(this.root, target, k, pq);
        while (!pq.isEmpty()) {
            HeapItem item = pq.peek();
            results.add(this.items.get(item.getIndex()));
            distances.add(item.getDistance());
            pq.next();
        }
        Collections.reverse(results);
        Collections.reverse(distances);
    }

    /**
     * Visits {@code node}, offering it as a candidate when it is closer than {@code tau}, then
     * descends into the children whose distance shells may still contain closer points (nearer
     * side first).
     *
     * <p>{@code tau} is only reset by {@link #search(DataPoint, int, List, List)}; calling this
     * overload directly reuses the radius left by the previous search.
     */
    public void search(Node node, DataPoint target, int k, PriorityQueue<HeapItem> pq) {
        if (node == null || k <= 0) {
            return;
        }
        DataPoint get = this.items.get(node.getIndex());
        double distance = this.getDistance(get, target);
        if (distance < this.tau) {
            if (pq.size() == k) {
                // Evict the current farthest candidate to make room.
                pq.next();
            }
            pq.add(new HeapItem(node.getIndex(), distance), distance);
            if (pq.size() == k) {
                this.tau = pq.peek().getDistance();
            }
        }
        if (node.getLeft() == null && node.getRight() == null) {
            return;
        }
        if (distance < node.getThreshold()) {
            if (distance - this.tau <= node.getThreshold()) {
                this.search(node.getLeft(), target, k, pq);
            }
            if (distance + this.tau >= node.getThreshold()) {
                this.search(node.getRight(), target, k, pq);
            }
        } else {
            if (distance + this.tau >= node.getThreshold()) {
                this.search(node.getRight(), target, k, pq);
            }
            if (distance - this.tau <= node.getThreshold()) {
                this.search(node.getLeft(), target, k, pq);
            }
        }
    }

    public CounterMap<DataPoint, DataPoint> getDistances() {
        return this.distances;
    }

    public void setDistances(CounterMap<DataPoint, DataPoint> distances) {
        this.distances = distances;
    }

    /**
     * A tree node: {@code index} is the position of the vantage point in {@code items}, and
     * {@code threshold} the distance separating the left (inside) and right (outside) subtrees.
     */
    public static class Node {
        private int index;
        private double threshold;
        private Node left;
        private Node right;

        public Node(int index, double threshold) {
            this.index = index;
            this.threshold = threshold;
        }

        /** Structural equality: same index, threshold and (recursively) equal subtrees. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            Node node = (Node) o;
            if (this.index != node.index) {
                return false;
            }
            if (Double.compare(node.threshold, this.threshold) != 0) {
                return false;
            }
            if (this.left != null ? !this.left.equals(node.left) : node.left != null) {
                return false;
            }
            return !(this.right == null ? node.right != null : !this.right.equals(node.right));
        }

        /** Hash consistent with {@link #equals}; recurses through the whole subtree. */
        @Override
        public int hashCode() {
            int result = this.index;
            long temp = Double.doubleToLongBits(this.threshold);
            result = 31 * result + (int) (temp ^ temp >>> 32);
            result = 31 * result + (this.left != null ? this.left.hashCode() : 0);
            result = 31 * result + (this.right != null ? this.right.hashCode() : 0);
            return result;
        }

        public int getIndex() {
            return this.index;
        }

        public void setIndex(int index) {
            this.index = index;
        }

        public double getThreshold() {
            return this.threshold;
        }

        public void setThreshold(double threshold) {
            this.threshold = threshold;
        }

        public Node getLeft() {
            return this.left;
        }

        public void setLeft(Node left) {
            this.left = left;
        }

        public Node getRight() {
            return this.right;
        }

        public void setRight(Node right) {
            this.right = right;
        }
    }
}

