/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering.lsh;

import java.util.Arrays;
import org.deeplearning4j.clustering.lsh.LSH;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Locality-sensitive hash index for cosine distance based on random hyperplane projections.
 *
 * <p>Vectors are multiplied by a random Gaussian matrix of shape [inDimension, hashLength] and the
 * sign of each projection forms the hash. An indexed vector falls in the query's bucket when every
 * hash sign matches. When {@code numTables > 1}, {@link #bucket(INDArray)} additionally probes
 * {@code numTables} randomly perturbed copies of the query and takes the union of their buckets.
 */
public class RandomProjectionLSH implements LSH {
    /** Number of random projections (hash signs) per vector. */
    private final int hashLength;
    /** Number of perturbed probes generated by {@link #entropy(INDArray)} for multi-probe lookup. */
    private final int numTables;
    /** Width of the vectors this index accepts. */
    private final int inDimension;
    /** Parameter handed to the Gaussian sampler in {@link #entropy(INDArray)}. */
    private final double radius;

    /** Random projection matrix, shape [inDimension, hashLength]. */
    INDArray randomProjection;
    /** Hash codes of the indexed data (one row per indexed vector); set by {@link #makeIndex}. */
    INDArray index;
    /** Data passed to {@link #makeIndex}; stored by reference, not copied. */
    INDArray indexData;

    @Override
    public String getDistanceMeasure() {
        return "cosinedistance";
    }

    /**
     * Creates a matrix filled with Gaussian samples of mean 0 and standard deviation
     * {@code 1 / sqrt(shape[0])}.
     *
     * @param shape shape of the matrix to create
     * @param rng   random generator used for sampling
     * @return the newly sampled matrix
     */
    private INDArray gaussianRandomMatrix(int[] shape, Random rng) {
        INDArray res = Nd4j.create(shape);
        GaussianDistribution op = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0]));
        Nd4j.getExecutioner().exec((RandomOp) op, rng);
        return res;
    }

    /**
     * Creates an index whose projection matrix is drawn from the global ND4J random generator.
     *
     * @see #RandomProjectionLSH(int, int, int, double, Random)
     */
    public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius) {
        this(hashLength, numTables, inDimension, radius, Nd4j.getRandom());
    }

    /**
     * Creates an index with an explicit random generator for the projection matrix.
     *
     * @param hashLength  number of projections (hash signs) per vector
     * @param numTables   number of perturbed probes used when bucketing a query
     * @param inDimension width of indexed and query vectors
     * @param radius      parameter of the Gaussian used to generate probes
     * @param rng         random generator used to sample the projection matrix
     */
    public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng) {
        this.hashLength = hashLength;
        this.numTables = numTables;
        this.inDimension = inDimension;
        this.radius = radius;
        this.randomProjection = gaussianRandomMatrix(new int[] {inDimension, hashLength}, rng);
    }

    /**
     * Generates {@code numTables} random probe points around a query.
     *
     * <p>A [numTables, inDimension] matrix is sampled from a Gaussian parameterized by
     * {@code radius}. Each row is divided by its L2 norm, and the query is then added to every row.
     *
     * <p>NOTE(review): because each row is normalized to unit length, {@code radius} does not
     * control how far the probes lie from the query. Confirm this is intended.
     *
     * @param data the query; presumably a single row of width {@code inDimension} (added with
     *             {@code addiRowVector}). TODO confirm against callers
     * @return a [numTables, inDimension] matrix of perturbed queries
     */
    public INDArray entropy(INDArray data) {
        INDArray data2 = Nd4j.getExecutioner().exec((RandomOp) new GaussianDistribution(
                Nd4j.create(new int[] {numTables, inDimension}), radius));
        // norm2 is a reduction and does not modify its input, so no defensive dup() is needed.
        INDArray norms = Nd4j.norm2(data2, -1);
        Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables,
                "Expected norm2 to have shape [%s], is %ndShape", (Object) numTables, (Object) norms);
        data2.diviColumnVector(norms);
        data2.addiRowVector(data);
        return data2;
    }

    /**
     * Hashes each row of {@code data} as the element-wise sign of its random projection.
     *
     * @param data rank-2 array whose second dimension equals {@code inDimension}
     * @return sign matrix of shape [data.rows, hashLength]
     * @throws ND4JIllegalStateException if {@code data} is not rank 2 or has the wrong width
     */
    public INDArray hash(INDArray data) {
        if (data.rank() != 2 || data.size(1) != inDimension) {
            throw new ND4JIllegalStateException(String.format(
                    "Invalid shape: Requested INDArray shape %s, this table expects dimension %d",
                    Arrays.toString(data.shape()), inDimension));
        }
        INDArray projected = data.mmul(randomProjection);
        return Nd4j.getExecutioner().exec((Op) new Sign(projected));
    }

    /**
     * Hashes {@code data} and keeps both the hashes and the data for later lookups.
     *
     * @param data the vectors to index, one per row
     */
    @Override
    public void makeIndex(INDArray data) {
        this.index = hash(data);
        this.indexData = data;
    }

    /**
     * Computes a 0/1 mask over the indexed rows whose hash equals the hash of {@code query}.
     *
     * <p>Hash signs are compared element-wise and then reduced with {@code min} over the last
     * dimension, so a row scores 1 only when every sign matches. Presumably {@code query} is a
     * single row. TODO confirm.
     */
    INDArray rawBucketOf(INDArray query) {
        INDArray pattern = hash(query);
        INDArray res = Nd4j.zeros(DataType.BOOL, index.shape());
        Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastEqualTo(index, pattern, res, new int[] {-1}));
        return res.castTo(Nd4j.defaultFloatingPointType()).min(-1);
    }

    /**
     * Returns a 0/1 mask over the indexed rows that fall in the query's bucket.
     *
     * <p>With more than one table, the buckets of {@code numTables} perturbed probes from
     * {@link #entropy(INDArray)} are unioned with the query's own bucket.
     */
    @Override
    public INDArray bucket(INDArray query) {
        INDArray queryRes = rawBucketOf(query);
        if (numTables > 1) {
            INDArray entropyQueries = entropy(query);
            for (int i = 0; i < numTables; i++) {
                INDArray row = entropyQueries.getRow(i, true);
                queryRes.addi(rawBucketOf(row));
            }
            // Summed hit counts are collapsed back to a 0/1 membership mask.
            BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0));
        }
        return queryRes;
    }

    /**
     * Gathers the indexed vectors that share a bucket with {@code query}, in index order.
     *
     * @return matrix of shape [bucketSize, inDimension]
     */
    INDArray bucketData(INDArray query) {
        INDArray mask = bucket(query);
        int nRes = mask.sum(0).getInt(0);
        INDArray res = Nd4j.create(new int[] {nRes, inDimension});
        int row = 0;
        // Copy every masked-in indexed row, starting from index 0 and never reading past the mask.
        for (long j = 0; j < mask.length() && row < nRes; j++) {
            if (mask.getDouble(j) > 0.0) {
                res.putRow(row++, indexData.getRow(j));
            }
        }
        return res;
    }

    /**
     * Returns bucket members whose cosine distance to {@code query} is at most {@code maxRange},
     * ordered by increasing distance.
     *
     * @throws IllegalArgumentException if {@code maxRange} is negative
     */
    @Override
    public INDArray search(INDArray query, double maxRange) {
        if (maxRange < 0.0) {
            throw new IllegalArgumentException("ANN search should have a positive maximum search radius");
        }
        INDArray bucketData = bucketData(query);
        INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
        INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);
        INDArray shuffleIndexes = idxs[0];
        INDArray sortedDistances = idxs[1];
        // Distances are sorted ascending, so the accepted results form a prefix.
        // Read them as doubles: cosine distances are fractional.
        int accepted = 0;
        while (accepted < sortedDistances.length() && sortedDistances.getDouble(accepted) <= maxRange) {
            accepted++;
        }
        INDArray res = Nd4j.create(new int[] {accepted, inDimension});
        for (int i = 0; i < accepted; i++) {
            res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
        }
        return res;
    }

    /**
     * Returns up to {@code k} bucket members closest to {@code query} by cosine distance,
     * ordered by increasing distance.
     *
     * @throws IllegalArgumentException if {@code k < 1}
     */
    @Override
    public INDArray search(INDArray query, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor");
        }
        INDArray bucketData = bucketData(query);
        INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
        INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);
        INDArray shuffleIndexes = idxs[0];
        INDArray sortedDistances = idxs[1];
        // length() (not shape()[1]) works whatever rank/orientation the distance vector has.
        long accepted = Math.min((long) k, sortedDistances.length());
        INDArray res = Nd4j.create(new long[] {accepted, inDimension});
        for (int i = 0; i < accepted; i++) {
            res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
        }
        return res;
    }

    @Override
    public int getHashLength() {
        return hashLength;
    }

    @Override
    public int getNumTables() {
        return numTables;
    }

    @Override
    public int getInDimension() {
        return inDimension;
    }

    /** Returns the radius parameter given at construction. */
    public double getRadius() {
        return radius;
    }
}

