/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.sptree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.sptree.Cell;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Space-partitioning tree built over the rows of a data matrix.
 *
 * <p>Each node owns a {@link Cell} boundary, a running mean ("center of mass") of every
 * point inserted into its subtree, and either a small buffer of row indices (leaf) or
 * {@code numChildren} child nodes (after {@link #subDivide()}). The number of children
 * doubles with every column of the data, starting from two for a single column.
 *
 * <p>NOTE(review): the root constructor passes the per-column mean as the cell "corner"
 * and the maximum deviation from that mean as its "width", so the corner appears to act
 * as the cell centre and the width as a half-extent. {@code Cell} is not visible from
 * this file - confirm against its {@code contains} implementation.
 *
 * <p>The class mutates its state without any synchronization.
 */
public class SpTree
implements Serializable {
    public static final String workspaceExternal = "SPTREE_LOOP_EXTERNAL";
    /** Number of columns of {@link #data}. */
    private int D;
    /** The data matrix; rows are the points stored in the tree. */
    private INDArray data;
    public static final int NODE_RATIO = 8000;
    /** Number of rows of {@link #data}. */
    private int N;
    /** Number of row indices currently held directly by this node (leaf storage). */
    private int size;
    /** Number of points accepted by this node or any of its descendants. */
    private int cumSize;
    private Cell boundary;
    /** Running mean of all points accepted by this subtree. */
    private INDArray centerOfMass;
    private SpTree parent;
    /** Row indices held by this node while it is a leaf; entries are set to -1 on subdivision. */
    private int[] index;
    /** Maximum number of row indices a leaf holds before it is subdivided. */
    private int nodeCapacity;
    private int numChildren = 2;
    private boolean isLeaf = true;
    /** Collection shared by the whole tree that receives every point stored in a leaf. */
    private Collection<INDArray> indices;
    private SpTree[] children;
    private static final Logger log = LoggerFactory.getLogger(SpTree.class);
    private String similarityFunction = Distance.EUCLIDEAN.toString();

    /**
     * Creates an empty node with an explicit boundary. No points are inserted.
     *
     * @param parent             the parent node, or {@code null} for a root
     * @param data               the data matrix whose rows are the points
     * @param corner             boundary corner (copied)
     * @param width              boundary width (copied)
     * @param indices            collection that receives every point stored in a leaf
     * @param similarityFunction name of the similarity function; only stored by this class
     */
    public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices, String similarityFunction) {
        this.init(parent, data, corner, width, indices, similarityFunction);
    }

    /**
     * Builds a root node covering every row of {@code data} and inserts all rows.
     *
     * <p>The data is duplicated first. The boundary corner is the per-column mean and the
     * boundary width is, per column, the larger of (max - mean) and (mean - min) plus
     * {@code Nd4j.EPS_THRESHOLD}. Rows are only inserted when {@code indices} is empty
     * (see {@link #fill(int)}).
     *
     * @param data               the data matrix whose rows are the points
     * @param indices            collection that receives every point stored in a leaf
     * @param similarityFunction name of the similarity function; only stored by this class
     */
    public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) {
        this.indices = indices;
        this.N = data.rows();
        this.D = data.columns();
        this.similarityFunction = similarityFunction;
        data = data.dup();
        INDArray meanY = data.mean(0);
        INDArray minY = data.min(0);
        INDArray maxY = data.max(0);
        INDArray width = Nd4j.create(data.dataType(), meanY.shape());
        for (int i = 0; i < width.length(); ++i) {
            double mean = meanY.getDouble(i);
            double extent = Math.max(maxY.getDouble(i) - mean, mean - minY.getDouble(i));
            // The epsilon keeps the width strictly positive even for a constant column.
            width.putScalar(i, extent + Nd4j.EPS_THRESHOLD);
        }
        // The tree is built with workspaces scoped out; the handle itself is only needed for closing.
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            this.init(null, data, meanY, width, indices, similarityFunction);
            this.fill(this.N);
        }
    }

    /**
     * Creates an empty node with an explicit boundary, using {@code "euclidean"} as the
     * similarity function name.
     */
    public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) {
        this(parent, data, corner, width, indices, "euclidean");
    }

    /** Builds and fills a root node, using {@code "euclidean"} as the similarity function name. */
    public SpTree(INDArray data, Collection<INDArray> indices) {
        this(data, indices, "euclidean");
    }

    /** Builds and fills a root node with a fresh, empty point collection. */
    public SpTree(INDArray data) {
        this(data, new ArrayList<INDArray>());
    }

    /**
     * @return always {@code null} in this implementation
     */
    public MemoryWorkspace workspace() {
        return null;
    }

    /**
     * Resets this node to an empty leaf with the given boundary.
     *
     * @param parent             the parent node, or {@code null} for a root
     * @param data               the data matrix (stored by reference)
     * @param corner             boundary corner (copied)
     * @param width              boundary width (copied)
     * @param indices            collection that receives every point stored in a leaf
     * @param similarityFunction name of the similarity function
     */
    private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices, String similarityFunction) {
        this.parent = parent;
        this.D = data.columns();
        this.N = data.rows();
        this.similarityFunction = similarityFunction;
        // N % NODE_RATIO is 0 whenever the row count is a multiple of NODE_RATIO (or 0).
        // A zero capacity means a leaf can never store a point, and every child created by
        // subDivide() inherits the same zero capacity, so insertion could never complete.
        // Clamp to at least one slot per leaf.
        this.nodeCapacity = Math.max(1, this.N % NODE_RATIO);
        this.index = new int[this.nodeCapacity];
        // Two children for one column, doubled for every further column.
        int childCount = 2;
        for (int d = 1; d < this.D; ++d) {
            childCount *= 2;
        }
        this.numChildren = childCount;
        this.indices = indices;
        this.isLeaf = true;
        this.size = 0;
        this.cumSize = 0;
        this.children = new SpTree[this.numChildren];
        this.data = data;
        this.boundary = new Cell(this.D);
        this.boundary.setCorner(corner.dup());
        this.boundary.setWidth(width.dup());
        this.centerOfMass = Nd4j.create(data.dataType(), new long[]{this.D});
    }

    /**
     * Inserts the data row with the given index into this subtree.
     *
     * @param index row index into the data matrix
     * @return {@code false} if the point lies outside this node's boundary, {@code true}
     *         if it was stored or found to duplicate a point already stored here
     * @throws IllegalStateException if the point is inside this boundary but no child accepts it
     */
    private boolean insert(int index) {
        INDArray point = this.data.slice(index);
        if (!this.boundary.contains(point)) {
            return false;
        }
        // Online update of the running mean: mean = mean * (n - 1) / n + point / n.
        ++this.cumSize;
        double mult1 = (double) (this.cumSize - 1) / (double) this.cumSize;
        double mult2 = 1.0 / (double) this.cumSize;
        this.centerOfMass.muli(mult1);
        this.centerOfMass.addi(point.mul(mult2));
        // A leaf with spare capacity stores the index directly.
        if (this.isLeaf() && this.size < this.nodeCapacity) {
            this.index[this.size] = index;
            this.indices.add(point);
            ++this.size;
            return true;
        }
        // A point equal to one already held here is counted (above) but not stored again.
        for (int i = 0; i < this.size; ++i) {
            INDArray compPoint = this.data.slice(this.index[i]);
            if (compPoint.equals(point)) {
                return true;
            }
        }
        if (this.isLeaf()) {
            this.subDivide();
        }
        for (int i = 0; i < this.numChildren; ++i) {
            if (this.children[i].insert(index)) {
                return true;
            }
        }
        throw new IllegalStateException("Shouldn't reach this state");
    }

    /**
     * Splits this node into {@code numChildren} children and moves the row indices held
     * here into them. Every child gets half of this node's width in each dimension and a
     * corner offset by plus or minus half the width, chosen per dimension from the bits of
     * the child's ordinal. Afterwards this node is no longer a leaf.
     */
    public void subDivide() {
        INDArray newCorner = Nd4j.create(this.data.dataType(), new long[]{this.D});
        INDArray newWidth = Nd4j.create(this.data.dataType(), new long[]{this.D});
        // The two scratch arrays are reused for every child; init() copies them.
        for (int i = 0; i < this.numChildren; ++i) {
            int div = 1;
            for (int d = 0; d < this.D; ++d) {
                double halfWidth = 0.5 * this.boundary.width(d);
                newWidth.putScalar(d, halfWidth);
                // Bit d of the child ordinal selects the side in dimension d.
                if (i / div % 2 == 1) {
                    newCorner.putScalar(d, this.boundary.corner(d) - halfWidth);
                } else {
                    newCorner.putScalar(d, this.boundary.corner(d) + halfWidth);
                }
                div *= 2;
            }
            this.children[i] = new SpTree(this, this.data, newCorner, newWidth, this.indices);
        }
        // Hand every stored index to the first child that accepts it.
        for (int i = 0; i < this.size; ++i) {
            boolean success = false;
            for (int j = 0; j < this.numChildren && !success; ++j) {
                success = this.children[j].insert(this.index[i]);
            }
            this.index[i] = -1;
        }
        this.size = 0;
        this.isLeaf = false;
    }

    /**
     * Accumulates the contribution of this subtree to the non-edge force on one point.
     *
     * <p>With {@code d2} the squared distance between the point and this node's center of
     * mass: if this node is a leaf, or {@code maxWidth / sqrt(d2) < theta}, the whole
     * subtree is summarised by its center of mass - {@code cumSize / (1 + d2)} is added to
     * {@code sumQ} and {@code cumSize / (1 + d2)^2 * (point - centerOfMass)} is added to
     * {@code negativeForce}. Otherwise the call recurses into every child.
     *
     * <p>Nothing is done for an empty subtree or for a leaf holding only the point itself.
     *
     * @param pointIndex    row index of the point
     * @param theta         threshold for the width-to-distance ratio
     * @param negativeForce accumulator, modified in place
     * @param sumQ          accumulator, modified in place
     */
    public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
        if (this.cumSize == 0 || this.isLeaf() && this.size == 1 && this.index[0] == pointIndex) {
            return;
        }
        // Allocate the scratch buffer only once the node is known to contribute.
        INDArray buf = Nd4j.create(this.data.dataType(), new long[]{this.D});
        // The difference is written into buf (passed as the result argument).
        this.data.slice(pointIndex).subi(this.centerOfMass, buf);
        double distSq = Nd4j.getBlasWrapper().dot(buf, buf);
        double maxWidth = this.boundary.width().maxNumber().doubleValue();
        if (this.isLeaf() || maxWidth / Math.sqrt(distSq) < theta) {
            double q = 1.0 / (1.0 + distSq);
            double mult = (double) this.cumSize * q;
            sumQ.addAndGet(mult);
            mult *= q;
            negativeForce.addi(buf.mul(mult));
        } else {
            for (int i = 0; i < this.numChildren; ++i) {
                this.children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
            }
        }
    }

    /**
     * Executes the {@link BarnesEdgeForces} custom op on the given arguments and this
     * tree's data.
     *
     * @param rowP must be a vector
     * @param colP passed through to the op
     * @param valP passed through to the op
     * @param N    passed through to the op
     * @param posF passed through to the op (presumably the output - confirm against the op)
     * @throws IllegalArgumentException if {@code rowP} is not a vector
     */
    public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
        if (!rowP.isVector()) {
            throw new IllegalArgumentException("RowP must be a vector");
        }
        Nd4j.exec((CustomOp) new BarnesEdgeForces(rowP, colP, valP, this.data, (long) N, posF));
    }

    /**
     * @return {@code true} until this node has been subdivided
     */
    public boolean isLeaf() {
        return this.isLeaf;
    }

    /**
     * Checks that every point held directly by a node lies inside that node's boundary,
     * for this node and all of its descendants.
     *
     * @return {@code true} if the whole subtree is consistent
     */
    public boolean isCorrect() {
        for (int n = 0; n < this.size; ++n) {
            INDArray point = this.data.slice(this.index[n]);
            if (!this.boundary.contains(point)) {
                return false;
            }
        }
        if (!this.isLeaf()) {
            for (int i = 0; i < this.numChildren; ++i) {
                if (!this.children[i].isCorrect()) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * @return the number of nodes on the longest path from this node down to a leaf;
     *         1 for a leaf
     */
    public int depth() {
        if (this.isLeaf()) {
            return 1;
        }
        int maxChildDepth = 0;
        for (int i = 0; i < this.numChildren; ++i) {
            // Inspect every child, not just the first one.
            maxChildDepth = Math.max(maxChildDepth, this.children[i].depth());
        }
        return 1 + maxChildDepth;
    }

    /**
     * Inserts rows {@code 0 .. n-1} into the tree. Does nothing except log a warning when
     * the shared point collection is not empty or this node has a parent.
     *
     * @param n number of rows to insert
     */
    private void fill(int n) {
        if (this.indices.isEmpty() && this.parent == null) {
            for (int i = 0; i < n; ++i) {
                log.trace("Inserted {}", i);
                this.insert(i);
            }
        } else {
            log.warn("Called fill already");
        }
    }

    /**
     * @return the internal child array (not a copy); its slots are only assigned by {@link #subDivide()}
     */
    public SpTree[] getChildren() {
        return this.children;
    }

    /**
     * @return the number of columns of the data
     */
    public int getD() {
        return this.D;
    }

    /**
     * @return the running mean of all points accepted by this subtree (not a copy)
     */
    public INDArray getCenterOfMass() {
        return this.centerOfMass;
    }

    /**
     * @return this node's boundary cell
     */
    public Cell getBoundary() {
        return this.boundary;
    }

    /**
     * @return the internal index buffer (not a copy)
     */
    public int[] getIndex() {
        return this.index;
    }

    /**
     * @return the number of points accepted by this subtree
     */
    public int getCumSize() {
        return this.cumSize;
    }

    /**
     * Overwrites the point count of this subtree. The center of mass is not adjusted.
     */
    public void setCumSize(int cumSize) {
        this.cumSize = cumSize;
    }

    /**
     * @return the number of children a subdivided node has
     */
    public int getNumChildren() {
        return this.numChildren;
    }

    /**
     * Overwrites the child count. The child array is not resized, so the loops over
     * children assume the new value does not exceed its length.
     */
    public void setNumChildren(int numChildren) {
        this.numChildren = numChildren;
    }
}

