/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.vptree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.clustering.vptree.VpTreePoint;

/**
 * A node of a vantage-point (VP) tree over points that expose a distance through
 * {@link VpTreePoint}.
 *
 * <p>An internal node stores a vantage point and a radius ({@code leftRadius}). Every point in the
 * left subtree is at distance {@code < leftRadius} from the vantage point. Every point in the
 * right subtree is at distance {@code >= leftRadius}. Leaf nodes ({@code left == null}) hold their
 * points directly, and queries scan them linearly.
 *
 * <p>The pruning in the searches relies on the triangle inequality and on symmetry. It assumes
 * {@code VpTreePoint.distance} is a true metric (TODO confirm for every implementation).
 */
public final class VpTreeNode<T extends VpTreePoint<T>> {
    /** Nodes with fewer points than this are not split further. Must be >= 20 (see below). */
    private static final int MAX_LEAF_SIZE = 25;
    /** Number of randomly sampled points considered as vantage-point candidates. */
    private static final int VANTAGE_POINT_CANDIDATES = 5;
    /** Number of randomly sampled points used to score each vantage-point candidate. */
    private static final int TEST_POINT_COUNT = 15;

    private VpTreeNode<T> left = null;
    private VpTreeNode<T> right = null;
    private T vantagePoint = null;
    private double leftRadius = 0.0;
    private final List<T> points;

    /**
     * Creates a leaf node holding {@code points}. Use {@link #buildVpTree(List)} to build a
     * searchable tree.
     *
     * @param points the points owned by this node (stored by reference, not copied)
     */
    public VpTreeNode(List<T> points) {
        this.points = points;
    }

    /**
     * Finds the {@code k} points closest to {@code point}, together with their distances.
     *
     * @param point the query point
     * @param k the number of neighbours to return; must be positive
     * @return a counter whose keys are the neighbours and whose counts are their distances to
     *     {@code point}; it holds fewer than {@code k} entries when the tree has fewer points
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    public Counter<T> findNearByPointsWithDistancesK(T point, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Illegal k, must be > 0");
        }
        Counter<T> solution = new Counter<T>();
        Set<T> seen = new HashSet<T>();
        for (Neighbor<T> neighbor : this.kNearest(point, k)) {
            // Increment each key only once, so its count equals its distance
            // (assuming a fresh Counter key starts at 0).
            if (seen.add(neighbor.point)) {
                solution.incrementCount(neighbor.point, neighbor.distance);
            }
        }
        return solution;
    }

    /**
     * Finds the {@code k} points closest to {@code point}.
     *
     * @param point the query point
     * @param k the number of neighbours to return; must be positive
     * @return up to {@code k} points, nearest first
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    public List<T> findNearByPointsK(T point, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Illegal k, must be > 0");
        }
        return this.findNearbyPoints(point, k);
    }

    /**
     * Finds the {@code k} points closest to {@code point} (exact search, not greedy).
     *
     * @param point the query point
     * @param k the number of neighbours to return
     * @return up to {@code k} points, nearest first; empty when {@code k <= 0}
     */
    public List<T> findNearbyPoints(T point, int k) {
        List<T> result = new ArrayList<T>();
        if (k <= 0) {
            return result;
        }
        for (Neighbor<T> neighbor : this.kNearest(point, k)) {
            result.add(neighbor.point);
        }
        return result;
    }

    /**
     * Finds every point within {@code maxDistance} of {@code point}, inclusive.
     *
     * @param point the query point
     * @param maxDistance the search radius
     * @return a new mutable list of all points {@code p} with
     *     {@code point.distance(p) <= maxDistance}, in no particular order
     */
    public List<T> findNearbyPoints(T point, double maxDistance) {
        List<T> result = new ArrayList<T>();
        this.collectWithin(point, maxDistance, result);
        return result;
    }

    /** Appends to {@code out} every point under this node within {@code maxDistance} of {@code point}. */
    private void collectWithin(T point, double maxDistance, List<T> out) {
        if (this.left == null || this.vantagePoint == null) {
            for (T candidate : this.points) {
                if (point.distance(candidate) <= maxDistance) {
                    out.add(candidate);
                }
            }
            return;
        }
        double distanceToVantage = this.vantagePoint.distance(point);
        // By the triangle inequality, any hit x satisfies
        // |d(vp, q) - maxDistance| <= d(vp, x) <= d(vp, q) + maxDistance,
        // so whole subtrees whose radius band cannot contain d(vp, x) are skipped.
        boolean searchRight = distanceToVantage + maxDistance >= this.leftRadius;
        boolean searchLeft = distanceToVantage - maxDistance < this.leftRadius;
        if (searchRight) {
            this.right.collectWithin(point, maxDistance, out);
        }
        if (searchLeft) {
            this.left.collectWithin(point, maxDistance, out);
        }
    }

    /** Returns up to {@code k >= 1} nearest neighbours of {@code point}, sorted nearest first. */
    private List<Neighbor<T>> kNearest(T point, int k) {
        // Max-heap on distance: the root is the current k-th best, i.e. the pruning radius.
        // The capacity is capped so that a huge k does not pre-allocate a huge array.
        int capacity = Math.max(1, Math.min(k, this.points.size()));
        PriorityQueue<Neighbor<T>> heap =
                new PriorityQueue<Neighbor<T>>(capacity, Collections.<Neighbor<T>>reverseOrder());
        this.searchK(point, k, heap);
        List<Neighbor<T>> result = new ArrayList<Neighbor<T>>(heap);
        Collections.sort(result);
        return result;
    }

    /** Depth-first VP-tree k-NN search that keeps the best {@code k} candidates in {@code heap}. */
    private void searchK(T point, int k, PriorityQueue<Neighbor<T>> heap) {
        if (this.left == null || this.vantagePoint == null) {
            for (T candidate : this.points) {
                offer(heap, k, candidate, point.distance(candidate));
            }
            return;
        }
        double distanceToVantage = this.vantagePoint.distance(point);
        // Visit the side containing the query first so the radius tightens before the
        // other side is tested.
        if (distanceToVantage < this.leftRadius) {
            this.left.searchK(point, k, heap);
            if (distanceToVantage + searchRadius(heap, k) >= this.leftRadius) {
                this.right.searchK(point, k, heap);
            }
        } else {
            this.right.searchK(point, k, heap);
            if (distanceToVantage - searchRadius(heap, k) < this.leftRadius) {
                this.left.searchK(point, k, heap);
            }
        }
    }

    /** Adds a candidate if the heap is not yet full or the candidate beats the current worst. */
    private static <P> void offer(PriorityQueue<Neighbor<P>> heap, int k, P candidate, double distance) {
        if (heap.size() < k) {
            heap.add(new Neighbor<P>(candidate, distance));
        } else if (distance < heap.peek().distance) {
            heap.poll();
            heap.add(new Neighbor<P>(candidate, distance));
        }
    }

    /** Returns the current pruning radius: infinite until {@code k} candidates are known. */
    private static <P> double searchRadius(PriorityQueue<Neighbor<P>> heap, int k) {
        return heap.size() < k ? Double.POSITIVE_INFINITY : heap.peek().distance;
    }

    /**
     * Builds a VP tree over a copy of {@code points}. The caller's list is not modified.
     *
     * @param points the points to index
     * @return the root node
     */
    public static <T extends VpTreePoint<T>> VpTreeNode<T> buildVpTree(List<T> points) {
        return VpTreeNode.buildTreeNode(new ArrayList<T>(points));
    }

    /**
     * Recursively builds a subtree over {@code points}, reordering the list in place so that each
     * child owns a contiguous {@code subList} view of it.
     */
    private static <T extends VpTreePoint<T>> VpTreeNode<T> buildTreeNode(List<T> points) {
        VpTreeNode<T> node = new VpTreeNode<T>(points);
        if (points.size() < MAX_LEAF_SIZE) {
            return node;
        }
        T basePoint = VpTreeNode.chooseNewVantagePoint(points);
        int size = points.size();
        double[] distances = new double[size];
        for (int i = 0; i < size; ++i) {
            distances[i] = basePoint.distance(points.get(i));
        }
        double[] sortedDistances = Arrays.copyOf(distances, size);
        Arrays.sort(sortedDistances);
        double radius = sortedDistances[size / 2];
        if (!(sortedDistances[0] < radius)) {
            // The median equals the minimum (e.g. many duplicate points), so "d < median" would
            // leave the left side empty and the right side would recurse forever. Since
            // d < nextUp(m) <=> d <= m, this moves the ties left while preserving the
            // invariant left: d < leftRadius, right: d >= leftRadius.
            radius = Math.nextUp(radius);
        }
        List<T> leftPoints = new ArrayList<T>(size);
        List<T> rightPoints = new ArrayList<T>(size);
        for (int i = 0; i < size; ++i) {
            if (distances[i] < radius) {
                leftPoints.add(points.get(i));
            } else {
                rightPoints.add(points.get(i));
            }
        }
        if (leftPoints.isEmpty() || rightPoints.isEmpty()) {
            // All points are equidistant from the vantage point (or the distances are NaN).
            // No split makes progress, so keep this node as a linearly scanned leaf.
            return node;
        }
        int leftSize = leftPoints.size();
        for (int i = 0; i < leftSize; ++i) {
            points.set(i, leftPoints.get(i));
        }
        for (int i = 0; i < rightPoints.size(); ++i) {
            points.set(leftSize + i, rightPoints.get(i));
        }
        node.vantagePoint = basePoint;
        node.leftRadius = radius;
        node.left = VpTreeNode.buildTreeNode(points.subList(0, leftSize));
        node.right = VpTreeNode.buildTreeNode(points.subList(leftSize, size));
        return node;
    }

    /**
     * Picks a vantage point. The candidate whose distances to a random test sample have the
     * largest spread wins.
     *
     * <p>This method reorders {@code points}: the sampled points are swapped to the front.
     * It requires {@code points.size() >= VANTAGE_POINT_CANDIDATES + TEST_POINT_COUNT}, which
     * {@code MAX_LEAF_SIZE} guarantees.
     */
    private static <T extends VpTreePoint<T>> T chooseNewVantagePoint(List<T> points) {
        List<T> candidates = new ArrayList<T>(VANTAGE_POINT_CANDIDATES);
        List<T> testPoints = new ArrayList<T>(TEST_POINT_COUNT);
        int sampleSize = VANTAGE_POINT_CANDIDATES + TEST_POINT_COUNT;
        // Partial Fisher-Yates shuffle: positions [0, sampleSize) receive a random sample
        // drawn without replacement.
        for (int i = 0; i < sampleSize; ++i) {
            int pick = i + (int) (Math.random() * (double) (points.size() - i));
            T sampled = points.get(pick);
            points.set(pick, points.get(i));
            points.set(i, sampled);
            if (i < VANTAGE_POINT_CANDIDATES) {
                candidates.add(sampled);
            } else {
                testPoints.add(sampled);
            }
        }
        T bestBasePoint = points.get(0);
        double bestSpread = 0.0;
        double[] distances = new double[TEST_POINT_COUNT];
        for (T candidate : candidates) {
            for (int i = 0; i < TEST_POINT_COUNT; ++i) {
                distances[i] = candidate.distance(testPoints.get(i));
            }
            double spread = VpTreeNode.sigmaSquare(distances);
            if (spread > bestSpread) {
                bestSpread = spread;
                bestBasePoint = candidate;
            }
        }
        return bestBasePoint;
    }

    /**
     * Returns the sum of squared deviations from the mean, i.e. {@code n} times the population
     * variance. That is enough for comparing samples of equal length.
     */
    private static double sigmaSquare(double[] values) {
        double sum = 0.0;
        for (double value : values) {
            sum += value;
        }
        double avg = sum / (double) values.length;
        double sigmaSq = 0.0;
        for (double value : values) {
            double dev = value - avg;
            sigmaSq += dev * dev;
        }
        return sigmaSq;
    }

    /** A candidate neighbour paired with its distance to the query; ordered by distance. */
    private static final class Neighbor<P> implements Comparable<Neighbor<P>> {
        private final P point;
        private final double distance;

        Neighbor(P point, double distance) {
            this.point = point;
            this.distance = distance;
        }

        @Override
        public int compareTo(Neighbor<P> other) {
            return Double.compare(this.distance, other.distance);
        }
    }
}

