/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.util.MathUtils;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.jblas.SimpleBlas;
import org.jblas.ranges.IntervalRange;
import org.jblas.ranges.Range;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for working with jblas {@link DoubleMatrix} instances:
 * element-wise transforms, simple statistics, random sampling, normalisation
 * and a few toy data-set generators.
 *
 * <p>Unless a method's documentation says otherwise, "linear index" refers to
 * the single-argument {@code get(int)}/{@code put(int, double)} accessors of
 * {@link DoubleMatrix}. Several methods modify their argument in place; this is
 * called out explicitly on each of them.
 */
public class MatrixUtil {
    private static final Logger log = LoggerFactory.getLogger(MatrixUtil.class);

    /**
     * Validates that two matrices are non-null and have the same number of rows.
     * Column counts are not compared.
     *
     * @param d1 first matrix
     * @param d2 second matrix
     * @throws IllegalArgumentException if either matrix is null or the row counts differ
     */
    public static void complainAboutMissMatchedMatrices(DoubleMatrix d1, DoubleMatrix d2) {
        if (d1 == null || d2 == null) {
            throw new IllegalArgumentException("No null matrices allowed");
        }
        if (d1.rows != d2.rows) {
            throw new IllegalArgumentException("Matrices must have same rows");
        }
    }

    /**
     * Clamps every element of {@code matrix} from below, in place, so that no
     * element is smaller than {@code minNumber}.
     *
     * @param minNumber the lower bound applied to every element
     * @param matrix    the matrix to modify in place
     */
    public static void max(double minNumber, DoubleMatrix matrix) {
        for (int i = 0; i < matrix.length; ++i) {
            // The floor is the caller-supplied bound, not a hard-coded zero.
            matrix.put(i, Math.max(minNumber, matrix.get(i)));
        }
    }

    /**
     * Divides every row of {@code toScale} by that row's maximum, in place.
     * A row whose maximum is zero is divided by zero (no guard is applied).
     *
     * @param toScale the matrix to modify in place
     */
    public static void scaleByMax(DoubleMatrix toScale) {
        DoubleMatrix rowMaxima = toScale.rowMaxs();
        for (int row = 0; row < toScale.rows; ++row) {
            double divisor = rowMaxima.get(row, 0);
            // getRow returns a copy, so the scaled row has to be written back.
            toScale.putRow(row, toScale.getRow(row).divi(divisor));
        }
    }

    /**
     * Computes the per-column population variance: the mean of the squared
     * deviations of each column from that column's mean.
     *
     * @param input the data, one observation per row
     * @return the column means of the squared deviations
     */
    public static DoubleMatrix variance(DoubleMatrix input) {
        DoubleMatrix means = input.columnMeans();
        DoubleMatrix squaredDeviations = MatrixFunctions.pow(input.subRowVector(means), 2.0);
        // columnMeans already divides by the row count; dividing again would
        // shrink the result by another factor of input.rows.
        return squaredDeviations.columnMeans();
    }

    /**
     * Computes one output value of a 2D kernel application: the sum over the
     * kernel window of {@code input(x + i, y + j) * k(i, j)}. The kernel is
     * applied as given (it is not flipped).
     *
     * @param input        the source matrix
     * @param x            first index of the window's origin in {@code input}
     * @param y            second index of the window's origin in {@code input}
     * @param k            the kernel
     * @param kernelWidth  extent of the kernel along its first index
     * @param kernelHeight extent of the kernel along its second index
     * @return the weighted sum for the window anchored at {@code (x, y)}
     */
    public static double singlePixelConvolution(DoubleMatrix input, int x, int y, DoubleMatrix k, int kernelWidth, int kernelHeight) {
        double acc = 0.0;
        for (int i = 0; i < kernelWidth; ++i) {
            for (int j = 0; j < kernelHeight; ++j) {
                acc += input.get(x + i, y + j) * k.get(i, j);
            }
        }
        return acc;
    }

    /**
     * Applies {@code kernel} at every position where it fits entirely inside the
     * {@code width x height} region of {@code input}.
     *
     * @param input        the source matrix
     * @param width        extent of the region along the first index
     * @param height       extent of the region along the second index
     * @param kernel       the kernel
     * @param kernelWidth  extent of the kernel along its first index
     * @param kernelHeight extent of the kernel along its second index
     * @return a {@code (width - kernelWidth + 1) x (height - kernelHeight + 1)} matrix
     */
    public static DoubleMatrix convolution2D(DoubleMatrix input, int width, int height, DoubleMatrix kernel, int kernelWidth, int kernelHeight) {
        int outWidth = width - kernelWidth + 1;
        int outHeight = height - kernelHeight + 1;
        DoubleMatrix output = DoubleMatrix.zeros(outWidth, outHeight);
        for (int i = 0; i < outWidth; ++i) {
            for (int j = 0; j < outHeight; ++j) {
                output.put(i, j, singlePixelConvolution(input, i, j, kernel, kernelWidth, kernelHeight));
            }
        }
        return output;
    }

    /**
     * Generates {@code n} random two-bit examples with one-hot XOR labels:
     * label column 0 is set when both bits are equal, column 1 otherwise.
     *
     * @param n number of examples
     * @return a data set of {@code n x 2} inputs and {@code n x 2} labels
     */
    public static DataSet xorData(int n) {
        // rand() draws values which are then thresholded in place to 0/1.
        DoubleMatrix x = DoubleMatrix.rand(n, 2).gti(0.5);
        DoubleMatrix y = DoubleMatrix.zeros(n, 2);
        for (int i = 0; i < x.rows; ++i) {
            int labelColumn = x.get(i, 0) == x.get(i, 1) ? 0 : 1;
            y.put(i, labelColumn, 1.0);
        }
        return new DataSet(x, y);
    }

    /**
     * Generates {@code n} examples of {@code columns} bits each, where every bit
     * is the XOR of two random bits. The label is one-hot: column 0 when the
     * left half (indices {@code [0, columns / 2)}) contains more set bits than
     * the right half, column 1 otherwise.
     *
     * @param n       number of examples
     * @param columns number of bits per example
     * @return a data set of {@code n x columns} inputs and {@code n x 2} labels
     */
    public static DataSet xorData(int n, int columns) {
        DoubleMatrix first = DoubleMatrix.rand(n, columns).gti(0.5);
        DoubleMatrix second = DoubleMatrix.rand(n, columns).gti(0.5);
        // (first == second) == 0  is 1 exactly where the two bits differ.
        DoubleMatrix xor = first.eq(second).eq(DoubleMatrix.zeros(n, columns));
        int median = columns / 2;
        DoubleMatrix outcomes = new DoubleMatrix(n, 2);
        for (int i = 0; i < outcomes.rows; ++i) {
            DoubleMatrix left = xor.get(i, (Range) new IntervalRange(0, median));
            DoubleMatrix right = xor.get(i, (Range) new IntervalRange(median, columns));
            int labelColumn = left.sum() > right.sum() ? 0 : 1;
            outcomes.put(i, labelColumn, 1.0);
        }
        return new DataSet(xor, outcomes);
    }

    /**
     * Computes the Euclidean length of {@code vec} over all of its elements.
     *
     * @param vec the vector (or matrix, treated as a flat list of elements)
     * @return the square root of the sum of squared elements
     */
    public static double magnitude(DoubleMatrix vec) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < vec.length; ++i) {
            double v = vec.get(i);
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Copies all elements of {@code d}, in linear-index order, into a new
     * {@code 1 x d.length} row vector.
     *
     * @param d the matrix to flatten
     * @return a new row vector holding the elements of {@code d}
     */
    public static DoubleMatrix unroll(DoubleMatrix d) {
        DoubleMatrix flat = new DoubleMatrix(1, d.length);
        for (int i = 0; i < d.length; ++i) {
            flat.put(i, d.get(i));
        }
        return flat;
    }

    /**
     * For every row of {@code d}, stores the index reported by
     * {@code SimpleBlas.iamax} for that row.
     *
     * @param d the matrix whose rows are scanned
     * @return a {@code d.rows x 1} column vector of indices (stored as doubles)
     */
    public static DoubleMatrix outcomes(DoubleMatrix d) {
        DoubleMatrix ret = new DoubleMatrix(d.rows, 1);
        for (int row = 0; row < d.rows; ++row) {
            ret.put(row, (double) SimpleBlas.iamax(d.getRow(row)));
        }
        return ret;
    }

    /**
     * Computes the cosine similarity of two vectors: the dot product of their
     * unit-length versions. The arguments are not modified.
     *
     * @param d1 first vector
     * @param d2 second vector
     * @return the dot product of the normalised copies
     */
    public static double cosineSim(DoubleMatrix d1, DoubleMatrix d2) {
        // unitVec scales in place, so normalise copies to leave the caller's data intact.
        DoubleMatrix unit1 = unitVec(d1.dup());
        DoubleMatrix unit2 = unitVec(d2.dup());
        return unit1.dot(unit2);
    }

    /**
     * Rescales {@code input} in place so that its minimum maps to 0 and its
     * maximum maps to 1. When all elements are equal the range is zero; in that
     * case every element becomes 0 rather than dividing by zero.
     *
     * @param input the matrix to modify in place
     * @return {@code input}, after rescaling
     */
    public static DoubleMatrix normalize(DoubleMatrix input) {
        double min = input.min();
        double max = input.max();
        double range = max - min;
        input.subi(min);
        if (range == 0.0) {
            return input;
        }
        return input.divi(range);
    }

    /**
     * Returns the square root of the sum of squared elements of {@code matrix}.
     * Despite the method name, no angle or similarity is computed here.
     *
     * @param matrix the matrix to measure; not modified
     * @return the square root of the sum of squares
     */
    public static double cosine(DoubleMatrix matrix) {
        return Math.sqrt(MatrixFunctions.pow(matrix, 2.0).sum());
    }

    /**
     * Scales {@code toScale} in place to unit Euclidean length. A vector whose
     * length is not strictly positive is returned unchanged.
     *
     * @param toScale the vector to scale in place
     * @return the value returned by {@code SimpleBlas.scal}, or {@code toScale}
     *         itself when no scaling was applied
     */
    public static DoubleMatrix unitVec(DoubleMatrix toScale) {
        double length = toScale.norm2();
        if (length > 0.0) {
            return SimpleBlas.scal(1.0 / length, toScale);
        }
        return toScale;
    }

    /**
     * Fills a new matrix with samples from a uniform distribution constructed
     * with bounds 0 and 1, drawn from {@code rng}.
     *
     * @param rng     the random generator backing the distribution
     * @param rows    number of rows
     * @param columns number of columns
     * @return a new {@code rows x columns} matrix of samples
     */
    public static DoubleMatrix uniform(RandomGenerator rng, int rows, int columns) {
        UniformRealDistribution dist = new UniformRealDistribution(rng, 0.0, 1.0);
        DoubleMatrix samples = new DoubleMatrix(rows, columns);
        for (int i = 0; i < samples.rows; ++i) {
            for (int j = 0; j < samples.columns; ++j) {
                samples.put(i, j, dist.sample());
            }
        }
        return samples;
    }

    /**
     * Draws one normal sample per element of {@code mean}, centred on that
     * element. Note that {@code sigma} is used as a variance: the standard
     * deviation passed to the distribution is {@code Math.sqrt(sigma)}.
     *
     * @param rng   the random generator to sample from; when null a default
     *              generator is used
     * @param mean  per-element means; also determines the result's shape
     * @param sigma the variance shared by all elements
     * @return a new matrix of samples with the shape of {@code mean}
     */
    public static DoubleMatrix normal(RandomGenerator rng, DoubleMatrix mean, double sigma) {
        double sd = Math.sqrt(sigma);
        DoubleMatrix samples = new DoubleMatrix(mean.rows, mean.columns);
        for (int i = 0; i < samples.rows; ++i) {
            for (int j = 0; j < samples.columns; ++j) {
                samples.put(i, j, sampleNormal(rng, mean.get(i, j), sd));
            }
        }
        return samples;
    }

    /**
     * Draws one normal sample per element of {@code mean}. The standard
     * deviation for column {@code j} is the square root of the variance at
     * linear index {@code j} of {@code variance}; non-positive variances are
     * replaced by {@code 1.0E-4} before the square root is taken. The
     * {@code variance} argument itself is not modified.
     *
     * @param rng      the random generator to sample from; when null a default
     *                 generator is used
     * @param mean     per-element means; also determines the result's shape
     * @param variance variances, read by linear index {@code 0 .. mean.columns - 1}
     * @return a new matrix of samples with the shape of {@code mean}
     */
    public static DoubleMatrix normal(RandomGenerator rng, DoubleMatrix mean, DoubleMatrix variance) {
        // Clamp first, then take the square root; doing it the other way round
        // leaves zero/NaN standard deviations in place.
        DoubleMatrix clamped = variance.dup();
        for (int i = 0; i < clamped.length; ++i) {
            if (clamped.get(i) <= 0.0) {
                clamped.put(i, 1.0E-4);
            }
        }
        DoubleMatrix std = MatrixFunctions.sqrt(clamped);
        DoubleMatrix samples = new DoubleMatrix(mean.rows, mean.columns);
        for (int i = 0; i < samples.rows; ++i) {
            for (int j = 0; j < samples.columns; ++j) {
                samples.put(i, j, sampleNormal(rng, mean.get(i, j), std.get(j)));
            }
        }
        return samples;
    }

    /**
     * Draws one zero-mean normal sample per element of
     * {@code standardDeviations}, using that element as the standard deviation.
     *
     * @param rng                the random generator to sample from; when null a
     *                           default generator is used
     * @param standardDeviations per-element standard deviations; also determines
     *                           the result's shape
     * @return a new matrix of samples
     */
    public static DoubleMatrix normal(RandomGenerator rng, DoubleMatrix standardDeviations) {
        DoubleMatrix samples = new DoubleMatrix(standardDeviations.rows, standardDeviations.columns);
        for (int i = 0; i < samples.rows; ++i) {
            for (int j = 0; j < samples.columns; ++j) {
                samples.put(i, j, sampleNormal(rng, 0.0, standardDeviations.get(i, j)));
            }
        }
        return samples;
    }

    /**
     * Draws a single sample from a normal distribution, using {@code rng} when
     * one is supplied and the distribution's default generator otherwise.
     */
    private static double sampleNormal(RandomGenerator rng, double mean, double sd) {
        NormalDistribution dist = rng == null
                ? new NormalDistribution(mean, sd)
                : new NormalDistribution(rng, mean, sd, NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
        return dist.sample();
    }

    /**
     * Reports whether {@code out} contains at least one strictly positive element.
     *
     * @param out the matrix to inspect
     * @return true if any element is greater than zero
     */
    public static boolean isValidOutcome(DoubleMatrix out) {
        for (int i = 0; i < out.length; ++i) {
            if (out.get(i) > 0.0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the smallest element of {@code matrix}.
     *
     * @param matrix a matrix with at least one element
     * @return the minimum element
     */
    public static double min(DoubleMatrix matrix) {
        double smallest = matrix.get(0);
        for (int i = 1; i < matrix.length; ++i) {
            double candidate = matrix.get(i);
            if (candidate < smallest) {
                smallest = candidate;
            }
        }
        return smallest;
    }

    /**
     * Returns the largest element of {@code matrix}.
     *
     * @param matrix a matrix with at least one element
     * @return the maximum element
     */
    public static double max(DoubleMatrix matrix) {
        double largest = matrix.get(0);
        for (int i = 1; i < matrix.length; ++i) {
            double candidate = matrix.get(i);
            if (candidate > largest) {
                largest = candidate;
            }
        }
        return largest;
    }

    /**
     * Ensures {@code out} has at least one strictly positive element; when it
     * has none, a warning is logged and the element at linear index 0 is set to 1.
     *
     * @param out the matrix to check and, if necessary, modify in place
     */
    public static void ensureValidOutcomeMatrix(DoubleMatrix out) {
        if (!isValidOutcome(out)) {
            log.warn("Found invalid matrix assuming; nothing which means adding a 1 to the first spot");
            out.put(0, 1.0);
        }
    }

    /**
     * Verifies that every element of {@code matrix} survives a round trip
     * through an {@code int} cast unchanged.
     *
     * @param matrix the matrix to check
     * @throws IllegalArgumentException at the first element that does not
     */
    public static void assertIntMatrix(DoubleMatrix matrix) {
        for (int i = 0; i < matrix.length; ++i) {
            double value = matrix.get(i);
            if ((double) ((int) value) != value) {
                throw new IllegalArgumentException("Found something that is not an integer at linear index " + i);
            }
        }
    }

    /**
     * Reports whether any element is flagged by {@code DoubleMatrix.isInfinite()}.
     *
     * @param test the matrix to inspect
     * @return true if at least one element is flagged
     */
    public static boolean isInfinite(DoubleMatrix test) {
        DoubleMatrix flags = test.isInfinite();
        for (int i = 0; i < flags.length; ++i) {
            if (flags.get(i) > 0.0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Reports whether any element of {@code test} is NaN.
     *
     * @param test the matrix to inspect
     * @return true if at least one element is NaN
     */
    public static boolean isNaN(DoubleMatrix test) {
        for (int i = 0; i < test.length; ++i) {
            if (Double.isNaN(test.get(i))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Replaces every column of {@code toDiscretize}, in place, with the bin
     * numbers returned by {@code MathUtils.discretize} for each value, given the
     * column's own minimum and maximum.
     *
     * @param toDiscretize the matrix to modify in place
     * @param numBins      the bin count forwarded to {@code MathUtils.discretize}
     */
    public static void discretizeColumns(DoubleMatrix toDiscretize, int numBins) {
        DoubleMatrix columnMaxes = toDiscretize.columnMaxs();
        DoubleMatrix columnMins = toDiscretize.columnMins();
        for (int c = 0; c < toDiscretize.columns; ++c) {
            double min = columnMins.get(c);
            double max = columnMaxes.get(c);
            DoubleMatrix column = toDiscretize.getColumn(c);
            DoubleMatrix binned = new DoubleMatrix(column.length);
            for (int r = 0; r < column.length; ++r) {
                int bin = MathUtils.discretize(column.get(r), min, max, numBins);
                binned.put(r, (double) bin);
            }
            toDiscretize.putColumn(c, binned);
        }
    }

    /**
     * Builds a new matrix whose elements are {@code MathUtils.roundDouble(value, num)}
     * for the corresponding elements of {@code d}.
     *
     * @param d   the source matrix; not modified
     * @param num the second argument forwarded to {@code MathUtils.roundDouble}
     * @return a new matrix with the shape of {@code d}
     */
    public static DoubleMatrix roundToTheNearest(DoubleMatrix d, int num) {
        // Every element is overwritten below, so a plain allocation suffices.
        DoubleMatrix rounded = new DoubleMatrix(d.rows, d.columns);
        for (int i = 0; i < d.rows; ++i) {
            for (int j = 0; j < d.columns; ++j) {
                rounded.put(i, j, MathUtils.roundDouble(d.get(i, j), num));
            }
        }
        return rounded;
    }

    /**
     * Divides every column of {@code x} by that column's sum, in place.
     * No guard is applied for a column whose sum is zero.
     *
     * @param x the matrix to modify in place
     */
    public static void columnNormalizeBySum(DoubleMatrix x) {
        for (int c = 0; c < x.columns; ++c) {
            DoubleMatrix column = x.getColumn(c);
            x.putColumn(c, column.div(column.sum()));
        }
    }

    /**
     * Builds a one-hot row vector of length {@code numOutcomes} with a 1 at
     * {@code index}.
     *
     * @param index       position of the 1; must lie in {@code [0, numOutcomes)}
     * @param numOutcomes length of the vector
     * @return a {@code 1 x numOutcomes} row vector
     */
    public static DoubleMatrix toOutcomeVector(int index, int numOutcomes) {
        int[] oneHot = new int[numOutcomes];
        oneHot[index] = 1;
        return toMatrix(oneHot);
    }

    /**
     * Converts a rectangular {@code int[][]} into a matrix of the same shape.
     * The column count is taken from the first row. An array with no rows
     * yields an empty {@code 0 x 0} matrix.
     *
     * @param arr the values, indexed {@code [row][column]}
     * @return a new {@code arr.length x arr[0].length} matrix
     */
    public static DoubleMatrix toMatrix(int[][] arr) {
        if (arr.length == 0) {
            return new DoubleMatrix(0, 0);
        }
        DoubleMatrix d = new DoubleMatrix(arr.length, arr[0].length);
        for (int i = 0; i < arr.length; ++i) {
            for (int j = 0; j < arr[i].length; ++j) {
                d.put(i, j, (double) arr[i][j]);
            }
        }
        return d;
    }

    /**
     * Converts an {@code int[]} into a {@code 1 x arr.length} row vector.
     *
     * @param arr the values
     * @return a new row vector
     */
    public static DoubleMatrix toMatrix(int[] arr) {
        DoubleMatrix d = new DoubleMatrix(arr.length);
        for (int i = 0; i < arr.length; ++i) {
            d.put(i, (double) arr[i]);
        }
        // The single-argument constructor is reshaped here into a row vector.
        d.reshape(1, d.length);
        return d;
    }

    /**
     * Adds {@code b} to {@code a} in place.
     *
     * @param a the matrix that receives the sum (modified)
     * @param b the matrix to add
     * @return the result of {@code a.addi(b)}
     */
    public static DoubleMatrix add(DoubleMatrix a, DoubleMatrix b) {
        return a.addi(b);
    }

    /**
     * Computes a row-wise softmax. Each row has its maximum subtracted before
     * exponentiation, then is divided by the row's sum of exponentials.
     *
     * @param input the scores, one example per row
     * @return a new matrix with the shape of {@code input}
     */
    public static DoubleMatrix softmax(DoubleMatrix input) {
        DoubleMatrix rowMaxima = input.rowMaxs();
        DoubleMatrix exponentials = MatrixFunctions.exp(input.subColumnVector(rowMaxima));
        exponentials.diviColumnVector(exponentials.rowSums());
        return exponentials;
    }

    /**
     * Computes means along an axis: {@code axis == 0} yields the column means of
     * {@code input}, {@code axis == 1} yields its row means. Any other axis
     * yields a zero-filled {@code input.rows x 1} column vector.
     *
     * @param input the data
     * @param axis  0 for column means, 1 for row means
     * @return the means along the requested axis
     */
    public static DoubleMatrix mean(DoubleMatrix input, int axis) {
        if (axis == 0) {
            return input.columnMeans();
        }
        if (axis == 1) {
            // Row means of the input itself, not of a freshly allocated matrix.
            return input.rowMeans();
        }
        return new DoubleMatrix(input.rows, 1);
    }

    /**
     * Computes sums along an axis. {@code axis == 0} yields the column sums as
     * an {@code input.columns x 1} column vector; any other axis yields the row
     * sums as an {@code input.rows x 1} column vector.
     *
     * @param input the data
     * @param axis  0 for column sums, anything else for row sums
     * @return a column vector of sums
     */
    public static DoubleMatrix sum(DoubleMatrix input, int axis) {
        if (axis == 0) {
            // One entry per column, so the result is sized by the column count.
            DoubleMatrix columnTotals = new DoubleMatrix(input.columns, 1);
            for (int c = 0; c < input.columns; ++c) {
                columnTotals.put(c, input.getColumn(c).sum());
            }
            return columnTotals;
        }
        DoubleMatrix rowTotals = new DoubleMatrix(input.rows, 1);
        for (int r = 0; r < input.rows; ++r) {
            rowTotals.put(r, input.getRow(r).sum());
        }
        return rowTotals;
    }

    /**
     * Builds a matrix whose elements are {@code MathUtils.binomial(rng, n, p_i)}
     * for the corresponding elements {@code p_i} of {@code p}.
     *
     * @param p   per-element values forwarded as the third argument
     * @param n   forwarded as the second argument
     * @param rng forwarded as the first argument
     * @return a new matrix with the shape of {@code p}
     */
    public static DoubleMatrix binomial(DoubleMatrix p, int n, RandomGenerator rng) {
        DoubleMatrix ret = new DoubleMatrix(p.rows, p.columns);
        for (int i = 0; i < ret.length; ++i) {
            ret.put(i, (double) MathUtils.binomial(rng, n, p.get(i)));
        }
        return ret;
    }

    /**
     * Returns a vector with {@code x.columns} entries, each set to the mean of
     * column {@code axis} of {@code x}.
     *
     * <p>NOTE(review): every entry receives the mean of the same column; the
     * loop index is not used to select the column. This may not be what the
     * method name suggests - confirm against callers before changing.
     *
     * @param x    the data
     * @param axis index of the column whose mean is used
     * @return a vector of length {@code x.columns}
     */
    public static DoubleMatrix columnWiseMean(DoubleMatrix x, int axis) {
        DoubleMatrix ret = DoubleMatrix.zeros(x.columns);
        for (int i = 0; i < x.columns; ++i) {
            ret.put(i, x.getColumn(axis).mean());
        }
        return ret;
    }

    /**
     * Averages the given matrices element-wise.
     *
     * @param matrices the matrices to average
     * @return null when {@code matrices} is null or empty; the sole matrix
     *         itself (not a copy) when exactly one is given; otherwise a new
     *         matrix holding the element-wise average
     */
    public static DoubleMatrix avg(DoubleMatrix... matrices) {
        if (matrices == null || matrices.length == 0) {
            return null;
        }
        if (matrices.length == 1) {
            return matrices[0];
        }
        DoubleMatrix total = matrices[0];
        for (int i = 1; i < matrices.length; ++i) {
            // add() allocates a new matrix, so the inputs are left untouched.
            total = total.add(matrices[i]);
        }
        return total.div((double) matrices.length);
    }

    /**
     * Returns the linear index of the first element equal to the matrix maximum.
     *
     * @param matrix the matrix to scan
     * @return the index of the first maximal element, or -1 if no element
     *         compares equal to {@code matrix.max()}
     */
    public static int maxIndex(DoubleMatrix matrix) {
        double max = matrix.max();
        for (int i = 0; i < matrix.length; ++i) {
            if (matrix.get(i) == max) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Computes the element-wise logistic function {@code 1 / (1 + exp(-x))}.
     *
     * @param x the input; not modified
     * @return a new matrix with the shape of {@code x}
     */
    public static DoubleMatrix sigmoid(DoubleMatrix x) {
        DoubleMatrix ones = DoubleMatrix.ones(x.rows, x.columns);
        DoubleMatrix denominator = ones.add(MatrixFunctions.exp(x.neg()));
        return ones.div(denominator);
    }

    /**
     * Multiplies {@code a} and {@code b}. When both arguments are vectors (row
     * or column), the result is their dot product wrapped in a scalar matrix;
     * otherwise it is the matrix product {@code a.mmul(b)}.
     *
     * @param a left operand
     * @param b right operand
     * @return a scalar matrix for two vectors, the matrix product otherwise
     */
    public static DoubleMatrix dot(DoubleMatrix a, DoubleMatrix b) {
        // Both operands must be vectors; without the grouping, && binds tighter
        // than || and a single vector operand would be enough to take this path.
        boolean aIsVector = a.isColumnVector() || a.isRowVector();
        boolean bIsVector = b.isColumnVector() || b.isRowVector();
        if (aIsVector && bIsVector) {
            return DoubleMatrix.scalar(a.dot(b));
        }
        return a.mmul(b);
    }

    /**
     * Returns the matrix product {@code a.mmul(b)}.
     *
     * @param a left operand
     * @param b right operand
     * @return the matrix product
     */
    public static DoubleMatrix out(DoubleMatrix a, DoubleMatrix b) {
        return a.mmul(b);
    }

    /**
     * Computes {@code scalar - ep} element-wise.
     *
     * @param scalar the minuend applied to every element
     * @param ep     the subtrahend; not modified
     * @return a new matrix with the shape of {@code ep}
     */
    public static DoubleMatrix scalarMinus(double scalar, DoubleMatrix ep) {
        DoubleMatrix filled = new DoubleMatrix(ep.rows, ep.columns);
        filled.addi(scalar);
        return filled.sub(ep);
    }

    /**
     * Computes {@code 1 - ep} element-wise.
     *
     * @param ep the subtrahend; not modified
     * @return a new matrix with the shape of {@code ep}
     */
    public static DoubleMatrix oneMinus(DoubleMatrix ep) {
        return DoubleMatrix.ones(ep.rows, ep.columns).sub(ep);
    }

    /**
     * Computes {@code 1 / ep} element-wise, treating every exact zero in
     * {@code ep} as {@code 0.01} to avoid dividing by zero. The argument is not
     * modified.
     *
     * @param ep the divisor
     * @return a new matrix with the shape of {@code ep}
     */
    public static DoubleMatrix oneDiv(DoubleMatrix ep) {
        // Substitute zeros on a copy so the caller's matrix keeps its values.
        DoubleMatrix divisor = ep.dup();
        for (int i = 0; i < divisor.rows; ++i) {
            for (int j = 0; j < divisor.columns; ++j) {
                if (divisor.get(i, j) == 0.0) {
                    divisor.put(i, j, 0.01);
                }
            }
        }
        return DoubleMatrix.ones(divisor.rows, divisor.columns).div(divisor);
    }

    /**
     * Computes the standard deviation of every column of {@code m} using
     * Commons Math's {@link StandardDeviation}.
     *
     * @param m the data
     * @return a {@code 1 x m.columns} row vector
     */
    public static DoubleMatrix columnStd(DoubleMatrix m) {
        DoubleMatrix ret = new DoubleMatrix(1, m.columns);
        StandardDeviation std = new StandardDeviation();
        for (int c = 0; c < m.columns; ++c) {
            ret.put(c, std.evaluate(m.getColumn(c).data));
        }
        return ret;
    }

    /**
     * Computes the standard deviation of every row of {@code m} using
     * Commons Math's {@link StandardDeviation}.
     *
     * @param m the data
     * @return a {@code 1 x m.rows} row vector, one entry per row of {@code m}
     */
    public static DoubleMatrix rowStd(DoubleMatrix m) {
        StandardDeviation std = new StandardDeviation();
        // One entry per row, so the result is sized by the row count.
        DoubleMatrix ret = new DoubleMatrix(1, m.rows);
        for (int r = 0; r < m.rows; ++r) {
            ret.put(r, std.evaluate(m.getRow(r).data));
        }
        return ret;
    }

    /**
     * Computes the mean of the squared element-wise differences between
     * {@code input} and {@code other}, pairing elements by linear index.
     * For zero-length matrices the result is {@code 0.0 / 0} (NaN).
     *
     * @param input first matrix
     * @param other second matrix
     * @return the mean squared difference
     * @throws IllegalArgumentException if the matrices differ in length
     */
    public static double meanSquaredError(DoubleMatrix input, DoubleMatrix other) {
        if (input.length != other.length) {
            throw new IllegalArgumentException("Matrices must be same length");
        }
        return squaredDifferenceSum(input, other) / input.length;
    }

    /**
     * Computes the element-wise natural logarithm. Any result that is NaN or
     * infinite is replaced by {@code 1.0E-6}.
     *
     * @param vals the input; not modified
     * @return a new matrix with the shape of {@code vals}
     */
    public static DoubleMatrix log(DoubleMatrix vals) {
        DoubleMatrix ret = new DoubleMatrix(vals.rows, vals.columns);
        for (int i = 0; i < vals.length; ++i) {
            double logVal = Math.log(vals.get(i));
            boolean finite = !Double.isNaN(logVal) && !Double.isInfinite(logVal);
            ret.put(i, finite ? logVal : 1.0E-6);
        }
        return ret;
    }

    /**
     * Computes the sum of the squared element-wise differences between
     * {@code input} and {@code other}, pairing elements by linear index.
     *
     * @param input first matrix
     * @param other second matrix
     * @return the sum of squared differences
     * @throws IllegalArgumentException if the matrices differ in length
     */
    public static double sumSquaredError(DoubleMatrix input, DoubleMatrix other) {
        if (input.length != other.length) {
            throw new IllegalArgumentException("Matrices must be same length");
        }
        return squaredDifferenceSum(input, other);
    }

    /** Sums {@code (a_i - b_i)^2} over all linear indices; lengths must already match. */
    private static double squaredDifferenceSum(DoubleMatrix a, DoubleMatrix b) {
        double total = 0.0;
        for (int i = 0; i < a.length; ++i) {
            double diff = a.get(i) - b.get(i);
            total += diff * diff;
        }
        return total;
    }

    /**
     * Standardises {@code toNormalize} in place: subtracts the column means,
     * then divides each column by its standard deviation plus {@code 1.0E-6}.
     *
     * @param toNormalize the matrix to modify in place
     */
    public static void normalizeMatrix(DoubleMatrix toNormalize) {
        DoubleMatrix columnMeans = toNormalize.columnMeans();
        toNormalize.subiRowVector(columnMeans);
        DoubleMatrix std = columnStd(toNormalize);
        // Small offset keeps the division defined for constant columns.
        std.addi(1.0E-6);
        toNormalize.diviRowVector(std);
    }

    /**
     * Divides every column of {@code m} by that column's sum, in place.
     * No guard is applied for a column whose sum is zero.
     *
     * @param m the matrix to modify in place
     * @return {@code m}
     */
    public static DoubleMatrix normalizeByColumnSums(DoubleMatrix m) {
        DoubleMatrix columnSums = m.columnSums();
        for (int c = 0; c < m.columns; ++c) {
            m.putColumn(c, m.getColumn(c).div(columnSums.get(c)));
        }
        return m;
    }

    /**
     * Computes the standard deviation of every column of {@code m} using
     * Commons Math's {@link StandardDeviation}.
     *
     * @param m the data
     * @return a {@code 1 x m.columns} row vector
     */
    public static DoubleMatrix columnStdDeviation(DoubleMatrix m) {
        DoubleMatrix ret = new DoubleMatrix(1, m.columns);
        for (int c = 0; c < ret.length; ++c) {
            StandardDeviation dev = new StandardDeviation();
            ret.put(c, dev.evaluate(m.getColumn(c).toArray()));
        }
        return ret;
    }

    /**
     * Divides every column of {@code m} by that column's standard deviation,
     * in place. No guard is applied for a zero standard deviation.
     *
     * @param m the matrix to modify in place
     * @return {@code m}
     */
    public static DoubleMatrix divColumnsByStDeviation(DoubleMatrix m) {
        DoubleMatrix std = columnStdDeviation(m);
        for (int c = 0; c < m.columns; ++c) {
            m.putColumn(c, m.getColumn(c).div(std.get(c)));
        }
        return m;
    }

    /**
     * Subtracts from every column of {@code m} that column's mean, in place.
     *
     * @param m the matrix to modify in place
     * @return {@code m}
     */
    public static DoubleMatrix normalizeByColumnMeans(DoubleMatrix m) {
        DoubleMatrix columnMeans = m.columnMeans();
        for (int c = 0; c < m.columns; ++c) {
            m.putColumn(c, m.getColumn(c).sub(columnMeans.get(c)));
        }
        return m;
    }

    /**
     * Divides every row of {@code m} by that row's sum, in place.
     * No guard is applied for a row whose sum is zero.
     *
     * @param m the matrix to modify in place
     * @return {@code m}
     */
    public static DoubleMatrix normalizeByRowSums(DoubleMatrix m) {
        DoubleMatrix rowSums = m.rowSums();
        for (int r = 0; r < m.rows; ++r) {
            m.putRow(r, m.getRow(r).div(rowSums.get(r)));
        }
        return m;
    }
}

