package org.deeplearning4j.clustering.vptree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.clustering.vptree.VpTreePoint;

/* loaded from: input_file:org/deeplearning4j/clustering/vptree/VpTreeNode.class */
/**
 * A node of a vantage-point (VP) tree, an index over points that define a pairwise distance via
 * {@link VpTreePoint#distance}. It supports k-nearest-neighbour and fixed-radius queries.
 *
 * <p>An internal node stores a vantage point and a radius ({@code leftRadius}). Every point in the
 * left subtree is strictly closer than {@code leftRadius} to the vantage point. Every point in the
 * right subtree is at or beyond that distance. Leaves (nodes without children) hold their points
 * directly, and every point of the tree belongs to exactly one leaf.
 *
 * <p>The subtree pruning in the queries assumes {@code distance} is a metric: non-negative,
 * symmetric, and satisfying the triangle inequality.
 *
 * @param <T> the point type
 */
public final class VpTreeNode<T extends VpTreePoint<T>> {
    /** Subsets smaller than this are kept as one leaf and scanned linearly. */
    private static final int MAX_LEAF_SIZE = 25;
    /** Number of randomly sampled points considered as the vantage point of a new node. */
    private static final int VANTAGE_POINT_CANDIDATES = 5;
    /** Number of randomly sampled points used to score each vantage-point candidate. */
    private static final int TEST_POINT_COUNT = 15;

    private VpTreeNode<T> left = null;
    private VpTreeNode<T> right = null;
    private T vantagePoint = null;
    private double leftRadius = 0.0d;
    private final List<T> points;

    /**
     * Creates a leaf over {@code list}. The list is stored as-is, without copying.
     *
     * @param list the points held by this node
     */
    public VpTreeNode(List<T> list) {
        this.points = list;
    }

    /**
     * Finds the {@code k} points nearest to {@code target}, together with their distances.
     *
     * @param target the query point
     * @param k the number of neighbours wanted; must be positive
     * @return a counter keyed by each of the nearest points, whose count is the point's distance
     *     to {@code target}. It has fewer than {@code k} keys only when the tree holds fewer points.
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    public Counter<T> findNearByPointsWithDistancesK(T target, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Illegal k, must be > 0");
        }
        Counter<T> result = new Counter<>();
        for (Neighbor<T> neighbor : nearestNeighbors(target, k)) {
            // NOTE(review): points that are equals() to each other share a single counter key.
            result.incrementCount(neighbor.point, neighbor.distance);
        }
        return result;
    }

    /**
     * Finds the {@code k} points nearest to {@code target}.
     *
     * @param target the query point
     * @param k the number of neighbours wanted; must be positive
     * @return the nearest points ordered by ascending distance (at most {@code k} of them)
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    public List<T> findNearByPointsK(T target, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Illegal k, must be > 0");
        }
        return toPointList(nearestNeighbors(target, k));
    }

    /**
     * Finds the {@code k} points nearest to {@code target}. This is a lenient variant of
     * {@link #findNearByPointsK}: a non-positive {@code k} yields an empty list instead of an
     * exception.
     *
     * @param target the query point
     * @param k the number of neighbours wanted
     * @return the nearest points ordered by ascending distance (at most {@code k} of them)
     */
    public List<T> findNearbyPoints(T target, int k) {
        if (k <= 0) {
            return new ArrayList<>();
        }
        return toPointList(nearestNeighbors(target, k));
    }

    /**
     * Finds every point whose distance to {@code target} is at most {@code radius}.
     *
     * @param target the query point
     * @param radius the inclusive search radius
     * @return a new list of matching points, in tree order (right subtree before left subtree)
     */
    public List<T> findNearbyPoints(T target, double radius) {
        List<T> result = new ArrayList<>();
        collectWithinRadius(target, radius, result);
        return result;
    }

    /** Appends this subtree's points that lie within {@code radius} of {@code target} to {@code out}. */
    private void collectWithinRadius(T target, double radius, List<T> out) {
        if (left == null) {
            for (T candidate : points) {
                if (target.distance(candidate) <= radius) {
                    out.add(candidate);
                }
            }
            return;
        }
        double toVantage = vantagePoint.distance(target);
        if (toVantage + radius < leftRadius) {
            // The whole query ball lies inside leftRadius, so only left-subtree points can match.
            left.collectWithinRadius(target, radius, out);
        } else if (toVantage - radius >= leftRadius) {
            // The whole query ball lies at or beyond leftRadius, so only right-subtree points can match.
            right.collectWithinRadius(target, radius, out);
        } else {
            right.collectWithinRadius(target, radius, out);
            left.collectWithinRadius(target, radius, out);
        }
    }

    /** Returns the (at most) {@code k} neighbours of {@code target}, sorted by ascending distance. */
    private List<Neighbor<T>> nearestNeighbors(T target, int k) {
        // Cap the initial capacity so a huge k does not pre-allocate a huge backing array.
        int capacity = Math.max(1, Math.min(k, points.size()));
        // Max-heap on distance: the head is the farthest of the best k candidates found so far.
        PriorityQueue<Neighbor<T>> best =
                new PriorityQueue<>(capacity, Collections.<Neighbor<T>>reverseOrder());
        collectNearest(target, k, best);
        List<Neighbor<T>> sorted = new ArrayList<>(best);
        Collections.sort(sorted);
        return sorted;
    }

    /** Offers this subtree's points to {@code best}, skipping subtrees that cannot improve it. */
    private void collectNearest(T target, int k, PriorityQueue<Neighbor<T>> best) {
        if (left == null) {
            for (T candidate : points) {
                offer(best, k, candidate, target.distance(candidate));
            }
            return;
        }
        double toVantage = vantagePoint.distance(target);
        if (toVantage < leftRadius) {
            // Search the side containing the query first; it usually tightens the bound fastest.
            left.collectNearest(target, k, best);
            // By the triangle inequality, any right-subtree point p has d(target, p) >= leftRadius - toVantage.
            if (best.size() < k || leftRadius - toVantage <= best.peek().distance) {
                right.collectNearest(target, k, best);
            }
        } else {
            right.collectNearest(target, k, best);
            // By the triangle inequality, any left-subtree point p has d(target, p) > toVantage - leftRadius.
            if (best.size() < k || toVantage - leftRadius <= best.peek().distance) {
                left.collectNearest(target, k, best);
            }
        }
    }

    /** Adds {@code point} to the bounded max-heap {@code best} if it is among the {@code k} closest so far. */
    private static <P> void offer(PriorityQueue<Neighbor<P>> best, int k, P point, double distance) {
        if (best.size() < k) {
            best.add(new Neighbor<>(point, distance));
        } else if (distance < best.peek().distance) {
            best.poll();
            best.add(new Neighbor<>(point, distance));
        }
    }

    /** Strips the distances from {@code neighbors}, keeping their order. */
    private static <P> List<P> toPointList(List<Neighbor<P>> neighbors) {
        List<P> result = new ArrayList<>(neighbors.size());
        for (Neighbor<P> neighbor : neighbors) {
            result.add(neighbor.point);
        }
        return result;
    }

    /**
     * Builds a VP tree over a copy of {@code list}. The caller's list is never reordered.
     *
     * @param list the points to index
     * @return the root of the new tree
     */
    public static <T extends VpTreePoint<T>> VpTreeNode<T> buildVpTree(List<T> list) {
        return buildTreeNode(new ArrayList<T>(list));
    }

    /**
     * Builds a subtree over {@code list}. The list is reordered in place so that each child can be
     * a {@code subList} view of it: points closer than the median distance to the vantage point
     * come first, the remaining points follow.
     */
    private static <T extends VpTreePoint<T>> VpTreeNode<T> buildTreeNode(List<T> list) {
        VpTreeNode<T> node = new VpTreeNode<>(list);
        if (list.size() < MAX_LEAF_SIZE) {
            return node;
        }
        T vantage = chooseNewVantagePoint(list);
        int size = list.size();
        double[] distances = new double[size];
        for (int i = 0; i < size; i++) {
            distances[i] = vantage.distance(list.get(i));
        }
        double[] sortedDistances = Arrays.copyOf(distances, size);
        Arrays.sort(sortedDistances);
        double median = sortedDistances[size / 2];

        List<T> inside = new ArrayList<>(size);
        List<T> outside = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            if (distances[i] < median) {
                inside.add(list.get(i));
            } else {
                outside.add(list.get(i));
            }
        }
        if (inside.isEmpty()) {
            // Nothing is strictly closer than the median, so the right child would receive the whole
            // subset again. With identical points that recursion never terminates. Keep a leaf instead.
            return node;
        }

        for (int i = 0; i < inside.size(); i++) {
            list.set(i, inside.get(i));
        }
        for (int i = 0; i < outside.size(); i++) {
            list.set(inside.size() + i, outside.get(i));
        }
        node.vantagePoint = vantage;
        node.leftRadius = median;
        node.left = buildTreeNode(list.subList(0, inside.size()));
        node.right = buildTreeNode(list.subList(inside.size(), size));
        return node;
    }

    /**
     * Picks a vantage point for {@code list}. It samples {@link #VANTAGE_POINT_CANDIDATES}
     * candidates and keeps the one whose distances to {@link #TEST_POINT_COUNT} sampled test points
     * have the largest spread. {@code list} is reordered by the sampling.
     */
    private static <T extends VpTreePoint<T>> T chooseNewVantagePoint(List<T> list) {
        // Partial Fisher-Yates shuffle. Positions [0, CANDIDATES) get distinct random candidates and
        // the next TEST_POINT_COUNT positions get distinct random test points. This needs
        // list.size() >= sampleCount, which buildTreeNode guarantees through MAX_LEAF_SIZE.
        int sampleCount = VANTAGE_POINT_CANDIDATES + TEST_POINT_COUNT;
        List<T> candidates = new ArrayList<>(VANTAGE_POINT_CANDIDATES);
        List<T> testPoints = new ArrayList<>(TEST_POINT_COUNT);
        for (int i = 0; i < sampleCount; i++) {
            int j = i + (int) (Math.random() * (list.size() - i));
            T picked = list.get(j);
            list.set(j, list.get(i));
            list.set(i, picked);
            if (i < VANTAGE_POINT_CANDIDATES) {
                candidates.add(picked);
            } else {
                testPoints.add(picked);
            }
        }

        T best = list.get(0);
        double bestSpread = 0.0d;
        double[] distances = new double[TEST_POINT_COUNT];
        for (T candidate : candidates) {
            for (int i = 0; i < TEST_POINT_COUNT; i++) {
                distances[i] = candidate.distance(testPoints.get(i));
            }
            double spread = sigmaSquare(distances);
            if (spread > bestSpread) {
                bestSpread = spread;
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Returns the sum of squared deviations of {@code values} from their mean, which is n times
     * the population variance. Only relative comparisons are made, so the missing division by n
     * does not matter.
     */
    private static double sigmaSquare(double[] values) {
        double sum = 0.0d;
        for (double value : values) {
            sum += value;
        }
        double mean = sum / values.length;
        double squaredDeviations = 0.0d;
        for (double value : values) {
            double deviation = value - mean;
            squaredDeviations += deviation * deviation;
        }
        return squaredDeviations;
    }

    /** A candidate point paired with its distance to the current query, ordered by distance. */
    private static final class Neighbor<P> implements Comparable<Neighbor<P>> {
        final P point;
        final double distance;

        Neighbor(P point, double distance) {
            this.point = point;
            this.distance = distance;
        }

        @Override
        public int compareTo(Neighbor<P> other) {
            return Double.compare(distance, other.distance);
        }
    }
}
