package smile.neighbor;

import java.lang.reflect.Array;
import java.util.List;
import smile.math.Math;
import smile.sort.HeapSelect;

/* loaded from: input_file:smile/neighbor/KDTree.class */
/**
 * A KD-tree over {@code double[]} keys that supports nearest-neighbor,
 * k-nearest-neighbor and fixed-radius (range) queries.
 *
 * <p>The tree does not copy the key or data arrays handed to the constructor;
 * it only builds a permutation ({@code index}) of their positions. Every tree
 * node covers a contiguous slice of that permutation.
 *
 * @param <E> the type of the data object associated with each key
 */
public class KDTree<E> implements NearestNeighborSearch<double[], E>, KNNSearch<double[], E>, RNNSearch<double[], E> {
    /** The keys (points) of the data set, used as given (not copied). */
    private final double[][] keys;
    /** The data objects; {@code data[i]} belongs to {@code keys[i]}. */
    private final E[] data;
    /** The root node, covering the whole data set. */
    private final KDTree<E>.Node root;
    /** Permutation of {@code 0..keys.length-1}; each node owns a contiguous slice of it. */
    private final int[] index;
    /**
     * When true, a stored key that is the very same array object as the query
     * (reference equality) is skipped during searches.
     */
    private boolean identicalExcluded = true;

    /** A node of the tree. A node with no children is a leaf and is scanned linearly. */
    public class Node {
        /** Number of points covered by this node. */
        int count;
        /** Offset into the enclosing tree's {@code index} array of this node's first point. */
        int index;
        /** The coordinate (dimension) this node splits on. */
        int split;
        /** Split value: keys below it go to {@code lower}, all others to {@code upper}. */
        double cutoff;
        /** Child holding the points whose split coordinate is below the cutoff. */
        KDTree<E>.Node lower;
        /** Child holding the points whose split coordinate is at or above the cutoff. */
        KDTree<E>.Node upper;

        Node() {
        }

        /** Returns true if this node has no children. */
        boolean isLeaf() {
            return this.lower == null && this.upper == null;
        }
    }

    /**
     * Builds a KD-tree over the given keys.
     *
     * @param dArr the keys; the dimensionality is taken from {@code dArr[0]}
     * @param eArr the data objects, one per key, in the same order
     * @throws IllegalArgumentException if the two arrays differ in length or are empty
     */
    public KDTree(double[][] dArr, E[] eArr) {
        if (dArr.length != eArr.length) {
            throw new IllegalArgumentException("The array size of keys and data are different.");
        }
        // buildNode reads keys[0], so an empty data set cannot be indexed.
        if (dArr.length == 0) {
            throw new IllegalArgumentException("Cannot build a KD-tree from an empty data set.");
        }
        this.keys = dArr;
        this.data = eArr;
        final int n = dArr.length;
        this.index = new int[n];
        for (int i = 0; i < n; i++) {
            this.index[i] = i;
        }
        this.root = buildNode(0, n);
    }

    @Override
    public String toString() {
        return "KD-Tree";
    }

    /**
     * Recursively builds the node covering {@code index[begin..end)}.
     *
     * <p>The node splits on the coordinate with the widest value range, at the
     * midpoint of that range. If all points coincide, or the split would leave
     * one side empty, the node becomes a leaf.
     *
     * @param begin first position (inclusive) in {@code index}
     * @param end last position (exclusive) in {@code index}
     * @return the node covering the range
     */
    private KDTree<E>.Node buildNode(int begin, int end) {
        final int dimension = this.keys[0].length;
        final KDTree<E>.Node node = new Node();
        node.count = end - begin;
        node.index = begin;

        // Per-coordinate bounding box of the points in this range.
        final double[] lowerBound = new double[dimension];
        final double[] upperBound = new double[dimension];
        final double[] first = this.keys[this.index[begin]];
        for (int j = 0; j < dimension; j++) {
            lowerBound[j] = first[j];
            upperBound[j] = first[j];
        }
        for (int p = begin + 1; p < end; p++) {
            final double[] key = this.keys[this.index[p]];
            for (int j = 0; j < dimension; j++) {
                final double value = key[j];
                if (lowerBound[j] > value) {
                    lowerBound[j] = value;
                }
                if (upperBound[j] < value) {
                    upperBound[j] = value;
                }
            }
        }

        // Pick the coordinate with the largest half-range; cut at its midpoint.
        double maxHalfRange = -1.0d;
        for (int j = 0; j < dimension; j++) {
            final double halfRange = (upperBound[j] - lowerBound[j]) / 2.0d;
            if (halfRange > maxHalfRange) {
                maxHalfRange = halfRange;
                node.split = j;
                node.cutoff = (upperBound[j] + lowerBound[j]) / 2.0d;
            }
        }

        // All points are identical: nothing to split, keep the node as a leaf.
        if (maxHalfRange == 0.0d) {
            return node;
        }

        // Two-pointer partition of index[begin..end) around the cutoff:
        // keys below the cutoff end up in front, the rest behind.
        int left = begin;
        int right = end - 1;
        int lowerCount = 0;
        while (left <= right) {
            boolean leftInPlace = this.keys[this.index[left]][node.split] < node.cutoff;
            boolean rightInPlace = this.keys[this.index[right]][node.split] >= node.cutoff;
            if (!leftInPlace && !rightInPlace) {
                // Both are on the wrong side: swapping puts both in place.
                final int tmp = this.index[left];
                this.index[left] = this.index[right];
                this.index[right] = tmp;
                leftInPlace = true;
                rightInPlace = true;
            }
            if (leftInPlace) {
                left++;
                lowerCount++;
            }
            if (rightInPlace) {
                right--;
            }
        }

        // A one-sided partition (midpoint rounded onto the minimum, overflow of
        // max + min, or NaN coordinates) would make a child cover exactly the
        // same range again and recurse without end. Stop at a leaf instead.
        if (lowerCount == 0 || lowerCount == node.count) {
            return node;
        }

        node.lower = buildNode(begin, begin + lowerCount);
        node.upper = buildNode(begin + lowerCount, end);
        return node;
    }

    /**
     * Sets whether a stored key that is the same array object as the query is
     * excluded from search results.
     *
     * @param z true to exclude the identical (same reference) key
     */
    public void setIdenticalExcluded(boolean z) {
        this.identicalExcluded = z;
    }

    /**
     * Returns whether a stored key that is the same array object as the query
     * is excluded from search results.
     *
     * @return the current setting
     */
    public boolean isIdenticalExcluded() {
        return this.identicalExcluded;
    }

    /**
     * Nearest-neighbor search below {@code node}. {@code best} is updated in
     * place and carries the SQUARED distance (as returned by
     * {@code Math.squaredDistance}) of the best candidate found so far.
     *
     * @param query the query point
     * @param node the subtree to search
     * @param best the best candidate so far; updated when a closer key is found
     */
    private void search(double[] query, KDTree<E>.Node node, Neighbor<double[], E> best) {
        if (node.isLeaf()) {
            final int end = node.index + node.count;
            for (int p = node.index; p < end; p++) {
                final int id = this.index[p];
                final double[] key = this.keys[id];
                if (this.identicalExcluded && query == key) {
                    continue;
                }
                final double squaredDistance = Math.squaredDistance(query, key);
                if (squaredDistance < best.distance) {
                    best.key = key;
                    best.value = this.data[id];
                    best.index = id;
                    best.distance = squaredDistance;
                }
            }
            return;
        }

        final double diff = query[node.split] - node.cutoff;
        final KDTree<E>.Node nearer = diff < 0.0d ? node.lower : node.upper;
        final KDTree<E>.Node further = diff < 0.0d ? node.upper : node.lower;

        search(query, nearer, best);
        // The far side can only help if the splitting plane is no farther away
        // than the current best (both sides of the comparison are squared).
        if (best.distance >= diff * diff) {
            search(query, further, best);
        }
    }

    /**
     * k-nearest-neighbor search below {@code node}. Distances held in the heap
     * are SQUARED distances. The element returned by {@code heap.peek()} is the
     * one that gets overwritten when a closer key is found.
     * NOTE(review): this presumes peek() yields the current worst candidate;
     * confirm against HeapSelect.
     *
     * @param query the query point
     * @param node the subtree to search
     * @param heap the candidates collected so far
     */
    private void search(double[] query, KDTree<E>.Node node, HeapSelect<Neighbor<double[], E>> heap) {
        if (node.isLeaf()) {
            final int end = node.index + node.count;
            for (int p = node.index; p < end; p++) {
                final int id = this.index[p];
                final double[] key = this.keys[id];
                if (this.identicalExcluded && query == key) {
                    continue;
                }
                final double squaredDistance = Math.squaredDistance(query, key);
                final Neighbor<double[], E> top = heap.peek();
                if (squaredDistance < top.distance) {
                    // Overwrite the top candidate in place, then restore heap order.
                    top.distance = squaredDistance;
                    top.index = id;
                    top.key = key;
                    top.value = this.data[id];
                    heap.heapify();
                }
            }
            return;
        }

        final double diff = query[node.split] - node.cutoff;
        final KDTree<E>.Node nearer = diff < 0.0d ? node.lower : node.upper;
        final KDTree<E>.Node further = diff < 0.0d ? node.upper : node.lower;

        search(query, nearer, heap);
        // Squared distance on both sides of the comparison.
        if (heap.peek().distance >= diff * diff) {
            search(query, further, heap);
        }
    }

    /**
     * Range search below {@code node}: appends every key whose distance (as
     * returned by {@code Math.distance}) to the query is at most
     * {@code radius}.
     *
     * @param query the query point
     * @param node the subtree to search
     * @param radius the search radius (a plain, non-squared distance)
     * @param neighbors the list that receives the matches
     */
    private void search(double[] query, KDTree<E>.Node node, double radius, List<Neighbor<double[], E>> neighbors) {
        if (node.isLeaf()) {
            final int end = node.index + node.count;
            for (int p = node.index; p < end; p++) {
                final int id = this.index[p];
                final double[] key = this.keys[id];
                if (this.identicalExcluded && query == key) {
                    continue;
                }
                final double distance = Math.distance(query, key);
                if (distance <= radius) {
                    neighbors.add(new Neighbor<>(key, this.data[id], id, distance));
                }
            }
            return;
        }

        final double diff = query[node.split] - node.cutoff;
        final KDTree<E>.Node nearer = diff < 0.0d ? node.lower : node.upper;
        final KDTree<E>.Node further = diff < 0.0d ? node.upper : node.lower;

        search(query, nearer, radius, neighbors);
        // The radius is a plain distance here, so it must be compared with the
        // plain (absolute) offset to the splitting plane. Comparing with the
        // squared offset would wrongly skip the far side for radii above 1.
        if (radius >= java.lang.Math.abs(diff)) {
            search(query, further, radius, neighbors);
        }
    }

    /**
     * Finds the stored key nearest to the query.
     *
     * @param dArr the query point
     * @return the nearest neighbor with its (non-squared) distance; if no key
     *         qualified, the returned neighbor has a null key and value
     */
    @Override // smile.neighbor.NearestNeighborSearch
    public Neighbor<double[], E> nearest(double[] dArr) {
        final Neighbor<double[], E> best = new Neighbor<>(null, null, 0, Double.MAX_VALUE);
        search(dArr, this.root, best);
        // The search works on squared distances; convert back before returning.
        best.distance = Math.sqrt(best.distance);
        return best;
    }

    /**
     * Finds the {@code i} stored keys nearest to the query.
     *
     * @param dArr the query point
     * @param i the number of neighbors (k) to return
     * @return an array of {@code i} neighbors with (non-squared) distances, in
     *         the order produced by {@code HeapSelect.sort()}; slots for which
     *         no key qualified keep a null key and value
     * @throws IllegalArgumentException if {@code i} is not positive or exceeds
     *         the number of stored keys
     */
    @Override // smile.neighbor.KNNSearch
    public Neighbor<double[], E>[] knn(double[] dArr, int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid k: " + i);
        }
        if (i > this.keys.length) {
            throw new IllegalArgumentException("Neighbor array length is larger than the dataset size");
        }

        // Generic array creation is not possible directly; the array only ever
        // holds Neighbor<double[], E> instances created below.
        @SuppressWarnings("unchecked")
        final Neighbor<double[], E>[] neighbors = (Neighbor<double[], E>[]) Array.newInstance(Neighbor.class, i);
        final HeapSelect<Neighbor<double[], E>> heap = new HeapSelect<>(neighbors);
        // Pre-fill the heap with placeholders that any real key will beat.
        for (int n = 0; n < i; n++) {
            final Neighbor<double[], E> placeholder = new Neighbor<>(null, null, 0, Double.MAX_VALUE);
            heap.add(placeholder);
        }

        search(dArr, this.root, heap);
        heap.sort();

        // The search works on squared distances; convert back before returning.
        for (Neighbor<double[], E> neighbor : neighbors) {
            neighbor.distance = Math.sqrt(neighbor.distance);
        }
        return neighbors;
    }

    /**
     * Collects all stored keys within the given radius of the query.
     *
     * @param dArr the query point
     * @param d the search radius; must be positive
     * @param list the list to which matching neighbors are appended
     * @throws IllegalArgumentException if {@code d} is not positive
     */
    @Override // smile.neighbor.RNNSearch
    public void range(double[] dArr, double d, List<Neighbor<double[], E>> list) {
        if (d <= 0.0d) {
            throw new IllegalArgumentException("Invalid radius: " + d);
        }
        search(dArr, this.root, d, list);
    }
}
