/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.util.Pair;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;

/**
 * Static helpers extending Apache Commons Math {@link MatrixUtils}.
 *
 * <p>This class provides:
 * <ul>
 *   <li>conversions from prediction inputs/outputs to vectors and matrices;</li>
 *   <li>inversion and pseudo-inversion;</li>
 *   <li>axis-wise reductions;</li>
 *   <li>a few element-wise helpers.</li>
 * </ul>
 */
public final class MatrixUtilsExtensions {
    private static final String SHAPE_STRING = "Matrix %s shape: %d x %d";

    /**
     * Threshold below which values are treated as zero.
     *
     * <p>It is used in two places:
     * <ul>
     *   <li>It is passed as the singularity threshold to {@link MatrixUtils#inverse(RealMatrix, double)}.</li>
     *   <li>Singular values not greater than it are zeroed in the pseudo-inverse.</li>
     * </ul>
     */
    private static final double SINGULARITY_THRESHOLD = 1.0E-6;

    private MatrixUtilsExtensions() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Converts the feature values of a prediction input into a vector.
     *
     * @param p the prediction input; each feature value is read via {@code asNumber()}
     * @return a vector with one entry per feature, in feature order
     */
    public static RealVector vectorFromPredictionInput(PredictionInput p) {
        return MatrixUtils.createRealVector(
                p.getFeatures().stream().mapToDouble(f -> f.getValue().asNumber()).toArray());
    }

    /**
     * Stacks the feature values of several prediction inputs into a matrix, one row per input.
     *
     * @param ps the prediction inputs; each feature value is read via {@code asNumber()}
     * @return a matrix of shape {@code ps.size() x #features}
     * @throws IllegalArgumentException if {@code ps} is empty or the inputs have different feature
     *     counts. This exception is raised by {@link MatrixUtils#createRealMatrix(double[][])}.
     */
    public static RealMatrix matrixFromPredictionInput(List<PredictionInput> ps) {
        return MatrixUtils.createRealMatrix(ps.stream()
                .map(p -> p.getFeatures().stream().mapToDouble(f -> f.getValue().asNumber()).toArray())
                .toArray(double[][]::new));
    }

    /**
     * Converts the output values of a prediction output into a vector.
     *
     * @param p the prediction output; each output value is read via {@code asNumber()}
     * @return a vector with one entry per output, in output order
     */
    public static RealVector vectorFromPredictionOutput(PredictionOutput p) {
        return MatrixUtils.createRealVector(
                p.getOutputs().stream().mapToDouble(o -> o.getValue().asNumber()).toArray());
    }

    /**
     * Stacks the output values of several prediction outputs into a matrix, one row per output.
     *
     * @param ps the prediction outputs; each output value is read via {@code asNumber()}
     * @return a matrix of shape {@code ps.size() x #outputs}
     * @throws IllegalArgumentException if {@code ps} is empty or the outputs have different
     *     lengths. This exception is raised by {@link MatrixUtils#createRealMatrix(double[][])}.
     */
    public static RealMatrix matrixFromPredictionOutput(List<PredictionOutput> ps) {
        return MatrixUtils.createRealMatrix(ps.stream()
                .map(p -> p.getOutputs().stream().mapToDouble(o -> o.getValue().asNumber()).toArray())
                .toArray(double[][]::new));
    }

    /**
     * Computes the Moore-Penrose pseudo-inverse of {@code a} via singular value decomposition.
     *
     * <p>Singular values not greater than {@link #SINGULARITY_THRESHOLD} (and NaNs) are treated
     * as zero rather than inverted. Rank-deficient matrices therefore do not produce huge
     * reciprocals.
     *
     * @param a the matrix to pseudo-invert, of shape m x n
     * @return the pseudo-inverse, of shape n x m
     */
    public static RealMatrix getPseudoInverse(RealMatrix a) {
        SingularValueDecomposition svd = new SingularValueDecomposition(a);
        double[] singularValues = svd.getSingularValues();
        double[] reciprocals = new double[singularValues.length];
        for (int i = 0; i < singularValues.length; ++i) {
            // truncate (near-)zero singular values instead of dividing by them
            reciprocals[i] = singularValues[i] > SINGULARITY_THRESHOLD ? 1.0 / singularValues[i] : 0.0;
        }
        // Sigma^+ is diagonal, hence its own transpose: A^+ = V * Sigma^+ * U^T
        RealMatrix sigmaPlus = MatrixUtils.createRealDiagonalMatrix(reciprocals);
        return svd.getV().multiply(sigmaPlus.multiply(svd.getU().transpose()));
    }

    /**
     * Misspelled alias of {@link #getPseudoInverse(RealMatrix)}, kept so existing callers keep compiling.
     *
     * @param a the matrix to pseudo-invert, of shape m x n
     * @return the pseudo-inverse, of shape n x m
     */
    public static RealMatrix getPsuedoInverse(RealMatrix a) {
        return getPseudoInverse(a);
    }

    /**
     * Inverts a square matrix, falling back to the pseudo-inverse if it is (numerically) singular.
     *
     * @param a a square matrix
     * @return the inverse of {@code a}, or its pseudo-inverse when {@link MatrixUtils#inverse}
     *     reports it as singular
     * @throws IllegalArgumentException if {@code a} is not square. This exception is raised by
     *     {@link MatrixUtils#inverse(RealMatrix, double)}.
     */
    public static RealMatrix safeInvert(RealMatrix a) {
        try {
            return MatrixUtils.inverse(a, SINGULARITY_THRESHOLD);
        } catch (SingularMatrixException e) {
            // no true inverse exists; the pseudo-inverse is the best available answer
            return getPseudoInverse(a);
        }
    }

    /**
     * Adds all rows of {@code m} together element-wise.
     *
     * @param m the matrix to reduce
     * @return a vector of length {@code m.getColumnDimension()} whose entry j is the sum of column j
     */
    public static RealVector rowSum(RealMatrix m) {
        RealVector out = MatrixUtils.createRealVector(new double[m.getColumnDimension()]);
        for (int i = 0; i < m.getRowDimension(); ++i) {
            out = out.add(m.getRowVector(i));
        }
        return out;
    }

    /**
     * Adds all columns of {@code m} together element-wise.
     *
     * @param m the matrix to reduce
     * @return a vector of length {@code m.getRowDimension()} whose entry i is the sum of row i
     */
    public static RealVector colSum(RealMatrix m) {
        RealVector out = MatrixUtils.createRealVector(new double[m.getRowDimension()]);
        for (int i = 0; i < m.getColumnDimension(); ++i) {
            out = out.add(m.getColumnVector(i));
        }
        return out;
    }

    /**
     * Adds the element-wise squares of all rows of {@code m} together.
     *
     * @param m the matrix to reduce
     * @return a vector of length {@code m.getColumnDimension()} whose entry j is the sum of
     *     squares of column j
     */
    public static RealVector rowSquareSum(RealMatrix m) {
        RealVector out = MatrixUtils.createRealVector(new double[m.getColumnDimension()]);
        for (int i = 0; i < m.getRowDimension(); ++i) {
            RealVector rv = m.getRowVector(i);
            out = out.add(rv.ebeMultiply(rv));
        }
        return out;
    }

    /**
     * Subtracts {@code v} from every row or every column of {@code m}.
     *
     * @param m the matrix to subtract from; it is not modified
     * @param v the vector to subtract. Its length must equal the column count for
     *     {@link Axis#ROW}, or the row count otherwise.
     * @param a {@link Axis#ROW} subtracts {@code v} from each row; any other value (including
     *     {@code null}) subtracts it from each column
     * @return a new matrix with the same shape as {@code m}
     */
    public static RealMatrix vectorDifference(RealMatrix m, RealVector v, Axis a) {
        RealMatrix out = m.createMatrix(m.getRowDimension(), m.getColumnDimension());
        if (a == Axis.ROW) {
            for (int i = 0; i < m.getRowDimension(); ++i) {
                out.setRowVector(i, m.getRowVector(i).subtract(v));
            }
        } else {
            for (int i = 0; i < m.getColumnDimension(); ++i) {
                out.setColumnVector(i, m.getColumnVector(i).subtract(v));
            }
        }
        return out;
    }

    /**
     * Multiplies every row of {@code m} element-wise by {@code v}.
     *
     * <p>This is equivalent to {@code m * diag(v)}.
     *
     * @param m matrix of shape r x c
     * @param v vector of length c
     * @return a new r x c matrix
     * @throws IllegalArgumentException if the length of {@code v} differs from the column count of {@code m}
     */
    public static RealMatrix vectorRowProduct(RealMatrix m, RealVector v) {
        int mRows = m.getRowDimension();
        int mCols = m.getColumnDimension();
        int vSize = v.getDimension();
        if (mCols != vSize) {
            throw new IllegalArgumentException("Columns of matrix A must match size of vector b. "
                    + String.format(SHAPE_STRING, "A", mRows, mCols) + ". "
                    + String.format("Size of vector b: %d", vSize));
        }
        RealMatrix out = MatrixUtils.createRealMatrix(mRows, mCols);
        for (int row = 0; row < mRows; ++row) {
            out.setRowVector(row, m.getRowVector(row).ebeMultiply(v));
        }
        return out;
    }

    /**
     * Applies {@code op} to every entry of {@code m}.
     *
     * @param m the source matrix; it is not modified
     * @param op the function to apply entry-wise
     * @return a new matrix of the same shape holding {@code op(m[i][j])}
     */
    public static RealMatrix map(RealMatrix m, UnivariateFunction op) {
        RealMatrix output = m.copy();
        for (int i = 0; i < m.getRowDimension(); ++i) {
            output.setRowVector(i, m.getRowVector(i).map(op));
        }
        return output;
    }

    /**
     * Extracts the given columns of {@code x}, in the order given.
     *
     * @param x the source matrix
     * @param idxs column indices into {@code x}; duplicates are allowed
     * @return a matrix of shape {@code x.getRowDimension() x idxs.size()}
     * @throws IllegalArgumentException if {@code idxs} is empty or any index is out of range
     */
    public static RealMatrix getCols(RealMatrix x, List<Integer> idxs) {
        if (idxs.isEmpty()) {
            throw new IllegalArgumentException("Empty column idxs passed to getCols");
        }
        RealMatrix out = MatrixUtils.createRealMatrix(new double[x.getRowDimension()][idxs.size()]);
        for (int col = 0; col < idxs.size(); ++col) {
            int sourceCol = idxs.get(col);
            if (sourceCol >= x.getColumnDimension() || sourceCol < 0) {
                throw new IllegalArgumentException(String.format(
                        "Column index %d out of bounds, matrix only has %d column(s)",
                        sourceCol, x.getColumnDimension()));
            }
            out.setColumnVector(col, x.getColumnVector(sourceCol));
        }
        return out;
    }

    /**
     * Computes the matrix product {@code a * b}.
     *
     * @param a matrix of shape n x k
     * @param b matrix of shape k x m
     * @return the n x m product
     * @throws IllegalArgumentException if the column count of {@code a} differs from the row count of {@code b}
     */
    public static RealMatrix matrixDot(RealMatrix a, RealMatrix b) {
        int aRows = a.getRowDimension();
        int aCols = a.getColumnDimension();
        int bRows = b.getRowDimension();
        int bCols = b.getColumnDimension();
        if (aCols != bRows) {
            throw new IllegalArgumentException("Columns of matrix A must match rows of matrix B. "
                    + String.format(SHAPE_STRING, "A", aRows, aCols) + ". "
                    + String.format(SHAPE_STRING, "B", bRows, bCols));
        }
        RealMatrix out = MatrixUtils.createRealMatrix(aRows, bCols);
        for (int row = 0; row < aRows; ++row) {
            // the row vector is invariant across the inner loop; fetch it once
            RealVector aRow = a.getRowVector(row);
            for (int col = 0; col < bCols; ++col) {
                out.setEntry(row, col, aRow.dotProduct(b.getColumnVector(col)));
            }
        }
        return out;
    }

    /**
     * Averages consecutive groups of {@code batchSize} rows.
     *
     * <p>Output row r is the mean of input rows {@code r*batchSize} through
     * {@code (r+1)*batchSize - 1}.
     *
     * @param m matrix whose row count is a multiple of {@code batchSize}
     * @param batchSize number of rows per group; must be positive
     * @return a matrix of shape {@code (rows / batchSize) x cols}
     * @throws IllegalArgumentException if {@code batchSize} is not positive, or the row count is
     *     not a multiple of it
     */
    public static RealMatrix batchRowMean(RealMatrix m, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException(String.format("Batch size must be positive, got %d", batchSize));
        }
        int mRows = m.getRowDimension();
        if (mRows % batchSize != 0) {
            throw new IllegalArgumentException(String.format(
                    "# Rows of matrix A (%d) must be an exact multiple of batch size %d", mRows, batchSize));
        }
        RealMatrix result = MatrixUtils.createRealMatrix(mRows / batchSize, m.getColumnDimension());
        for (int i = 0; i < mRows; i += batchSize) {
            RealVector batch = m.getRowVector(i);
            for (int j = i + 1; j < i + batchSize; ++j) {
                batch = batch.add(m.getRowVector(j));
            }
            batch.mapDivideToSelf(batchSize);
            result.setRowVector(i / batchSize, batch);
        }
        return result;
    }

    /**
     * Computes {@code A^T * diag(B) * A} and {@code A^T * diag(B) * C} in one call.
     *
     * <p>The diagonal matrix built from {@code B} is never materialised.
     *
     * @param A matrix of shape n x p
     * @param B vector of length n, used as the diagonal of the middle matrix
     * @param C vector of length n
     * @return a pair holding:
     *     <ul>
     *       <li>the p x p matrix whose entry (i, j) is {@code sum_k A[k,j] * B[k] * A[k,i]};</li>
     *       <li>the length-p vector whose entry i is {@code sum_k A[k,i] * B[k] * C[k]}.</li>
     *     </ul>
     * @throws IllegalArgumentException if {@code B} or {@code C} does not have one entry per row of {@code A}
     */
    public static Pair<RealMatrix, RealVector> jointATBATandATBC(RealMatrix A, RealVector B, RealVector C) {
        int nRows = A.getRowDimension();
        int nCols = A.getColumnDimension();
        if (B.getDimension() != nRows || C.getDimension() != nRows) {
            throw new IllegalArgumentException(String.format(
                    "Vectors B and C must have one entry per row of matrix A. " + SHAPE_STRING
                            + ". Size of vector B: %d. Size of vector C: %d",
                    "A", nRows, nCols, B.getDimension(), C.getDimension()));
        }
        RealMatrix atbat = MatrixUtils.createRealMatrix(nCols, nCols);
        RealVector atbc = MatrixUtils.createRealVector(new double[nCols]);
        for (int i = 0; i < nCols; ++i) {
            for (int j = 0; j < nCols; ++j) {
                double atbatElement = 0.0;
                for (int k = 0; k < nRows; ++k) {
                    atbatElement += A.getEntry(k, j) * B.getEntry(k) * A.getEntry(k, i);
                }
                atbat.setEntry(i, j, atbatElement);
            }
            double atbcElement = 0.0;
            for (int k = 0; k < nRows; ++k) {
                atbcElement += A.getEntry(k, i) * B.getEntry(k) * C.getEntry(k);
            }
            atbc.setEntry(i, atbcElement);
        }
        return new Pair<>(atbat, atbc);
    }

    /**
     * Finds the smallest strictly positive entry of {@code v}.
     *
     * @param v the vector to scan
     * @return the smallest entry greater than zero, or {@link Double#MAX_VALUE} if there is none
     */
    public static double minPos(RealVector v) {
        double minPos = Double.MAX_VALUE;
        for (int i = 0; i < v.getDimension(); ++i) {
            double vI = v.getEntry(i);
            if (vI > 0.0 && vI < minPos) {
                minPos = vI;
            }
        }
        return minPos;
    }

    /**
     * Lists the indices of the entries of {@code v} that are not exactly {@code 0.0}.
     *
     * <p>NaN entries count as non-zero.
     *
     * @param v the vector to scan
     * @return indices in ascending order
     */
    public static List<Integer> nonzero(RealVector v) {
        List<Integer> output = new ArrayList<>();
        for (int i = 0; i < v.getDimension(); ++i) {
            if (v.getEntry(i) != 0.0 || Double.isNaN(v.getEntry(i))) {
                output.add(i);
            }
        }
        return output;
    }

    /**
     * Computes the population variance of the entries of {@code v}.
     *
     * <p>The squared deviations are divided by the number of entries, not by that number minus one.
     *
     * @param v the vector of samples
     * @return the variance, or NaN for an empty vector
     */
    public static double variance(RealVector v) {
        double mean = Arrays.stream(v.toArray()).sum() / (double) v.getDimension();
        return Arrays.stream(v.map(a -> Math.pow(a - mean, 2.0)).toArray()).sum() / (double) v.getDimension();
    }

    /**
     * Sums the entries of {@code v}.
     *
     * @param v the vector to sum
     * @return the sum of all entries (0.0 for an empty vector)
     */
    public static double sum(RealVector v) {
        return Arrays.stream(v.toArray()).sum();
    }

    /**
     * Swaps rows {@code i} and {@code j} of {@code x} in place.
     *
     * @param x the matrix to modify
     * @param i index of the first row
     * @param j index of the second row
     */
    public static void swap(RealMatrix x, int i, int j) {
        double[] tmp = x.getRow(i);
        x.setRow(i, x.getRow(j));
        x.setRow(j, tmp);
    }

    /**
     * Swaps entries {@code i} and {@code j} of {@code x} in place.
     *
     * @param x the vector to modify
     * @param i index of the first entry
     * @param j index of the second entry
     */
    public static void swap(RealVector x, int i, int j) {
        double tmp = x.getEntry(i);
        x.setEntry(i, x.getEntry(j));
        x.setEntry(j, tmp);
    }

    /**
     * Swaps elements {@code i} and {@code j} of {@code x} in place.
     *
     * @param x the array to modify
     * @param i index of the first element
     * @param j index of the second element
     */
    public static void swap(int[] x, int i, int j) {
        int tmp = x[i];
        x[i] = x[j];
        x[j] = tmp;
    }

    /**
     * Collects column {@code idx} of a two-dimensional array.
     *
     * @param m the array; every row must have an element at {@code idx}
     * @param idx the column index
     * @param <T> the element type
     * @return a new mutable list with {@code m[i][idx]} for each row i, in row order
     */
    public static <T> List<T> getColumn(T[][] m, int idx) {
        List<T> out = new ArrayList<>(m.length);
        for (T[] row : m) {
            out.add(row[idx]);
        }
        return out;
    }

    /** Axis along which {@link #vectorDifference(RealMatrix, RealVector, Axis)} subtracts. */
    public enum Axis {
        ROW,
        COLUMN
    }
}

