package org.apache.sysds.runtime.matrix.data;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.ProgramConverter;

/* loaded from: input_file:org/apache/sysds/runtime/matrix/data/LibCommonsMath.class */
/**
 * Static library of matrix factorizations and solvers.
 *
 * <p>Most operations (inverse, Cholesky, QR, LU, eigen, SVD, solve) delegate to Apache Commons Math
 * after converting the input {@link MatrixBlock} with {@link DataConverter}. A few operations are
 * implemented natively on top of SystemDS matrix operators:
 * <ul>
 *   <li>{@code qr2}: Householder QR of a square matrix,</li>
 *   <li>{@code eigen_qr}: unshifted QR-iteration eigen decomposition,</li>
 *   <li>{@code eigen_lanczos}: Lanczos tridiagonalization followed by {@code eigen}.</li>
 * </ul>
 * Eigen results are returned as {eigenvalues (column vector, ascending), eigenvectors (columns)}.
 */
public class LibCommonsMath {
    private static final Log LOG = LogFactory.getLog(LibCommonsMath.class.getName());

    /** Relative symmetry threshold handed to Commons Math's {@link CholeskyDecomposition}. */
    private static final double RELATIVE_SYMMETRY_THRESHOLD = 1.0E-6d;

    /** Absolute positivity threshold handed to Commons Math's {@link CholeskyDecomposition}. */
    private static final double ABSOLUTE_POSITIVITY_THRESHOLD = 1.0E-10d;

    /** Magnitude of the sign-preserving perturbation applied by the regularized eigen fallback. */
    private static final double EIGEN_LAMBDA = 1.0E-8d;

    /** Iteration cap used by {@code eigen_qr} when the caller does not supply one. */
    private static final int EIGEN_QR_DEFAULT_ITERATIONS = 100;

    /** Convergence tolerance used by {@code eigen_qr} when the caller does not supply one. */
    private static final double EIGEN_QR_DEFAULT_TOLERANCE = 1.0E-10d;

    /** Maximum allowed deviation of the squared norm of the random Lanczos start vector from 1. */
    private static final double NORMALIZATION_TOLERANCE = 1.0E-7d;

    private LibCommonsMath() {
        // static utility class; not instantiable
    }

    /**
     * Tells whether {@link #unaryOperations(MatrixBlock, String)} handles the given opcode.
     *
     * @param opcode operation name (must not be null)
     * @return {@code true} for {@code "inverse"} and {@code "cholesky"}
     */
    public static boolean isSupportedUnaryOperation(String opcode) {
        return opcode.equals("inverse") || opcode.equals("cholesky");
    }

    /**
     * Tells whether the given opcode is advertised as a supported multi-return operation.
     *
     * <p>Note: {@link #multiReturnOperations(MatrixBlock, String, int, long)} additionally dispatches
     * {@code "qr2"}, {@code "eigen_lanczos"} and {@code "eigen_qr"}, which are not reported here.
     *
     * @param opcode operation name (must not be null)
     * @return {@code true} for {@code "qr"}, {@code "lu"}, {@code "eigen"} and {@code "svd"}
     */
    public static boolean isSupportedMultiReturnOperation(String opcode) {
        return opcode.equals("qr") || opcode.equals("lu") || opcode.equals("eigen") || opcode.equals("svd");
    }

    /**
     * Tells whether {@link #matrixMatrixOperations(MatrixBlock, MatrixBlock, String)} handles the opcode.
     *
     * @param opcode operation name (must not be null)
     * @return {@code true} only for {@code "solve"}
     */
    public static boolean isSupportedMatrixMatrixOperation(String opcode) {
        return opcode.equals("solve");
    }

    /**
     * Executes a single-output operation via Commons Math.
     *
     * @param in     input matrix
     * @param opcode {@code "inverse"} or {@code "cholesky"}
     * @return the result, or {@code null} if the opcode is not supported
     * @throws DMLRuntimeException if the input is not square
     */
    public static MatrixBlock unaryOperations(MatrixBlock in, String opcode) {
        if (!isSupportedUnaryOperation(opcode)) {
            return null; // avoid converting the input for an operation we do not handle
        }
        Array2DRowRealMatrix matrix = DataConverter.convertToArray2DRowRealMatrix(in);
        return opcode.equals("inverse") ? computeMatrixInverse(matrix) : computeCholesky(matrix);
    }

    /**
     * Executes a multi-output operation single-threaded with seed {@code 1}.
     *
     * @param in     input matrix
     * @param opcode operation name
     * @return the outputs, or {@code null} if the opcode is not supported
     */
    public static MatrixBlock[] multiReturnOperations(MatrixBlock in, String opcode) {
        return multiReturnOperations(in, opcode, 1, 1L);
    }

    /**
     * Executes a multi-output operation; {@code "eigen_qr"} uses the given iteration cap and
     * tolerance, every other opcode is forwarded with seed {@code 1}.
     *
     * @param in            input matrix
     * @param opcode        operation name
     * @param threads       degree of parallelism for the SystemDS operators
     * @param numIterations maximum number of QR iterations (only for {@code "eigen_qr"})
     * @param tol           convergence tolerance (only for {@code "eigen_qr"})
     * @return the outputs, or {@code null} if the opcode is not supported
     */
    public static MatrixBlock[] multiReturnOperations(MatrixBlock in, String opcode, int threads,
            int numIterations, double tol) {
        return opcode.equals("eigen_qr")
            ? computeEigenQR(in, numIterations, tol, threads)
            : multiReturnOperations(in, opcode, threads, 1L);
    }

    /**
     * Executes a multi-output operation.
     *
     * @param in      input matrix
     * @param opcode  one of {@code qr, qr2, lu, eigen, eigen_lanczos, eigen_qr, svd}
     * @param threads degree of parallelism for the natively implemented operations
     * @param seed    seed of the random start vector (only for {@code "eigen_lanczos"})
     * @return the outputs, or {@code null} if the opcode is not supported
     */
    public static MatrixBlock[] multiReturnOperations(MatrixBlock in, String opcode, int threads, long seed) {
        switch (opcode) {
            case "qr":
                return computeQR(in);
            case "qr2":
                return computeQR2(in, threads);
            case "lu":
                return computeLU(in);
            case "eigen":
                return computeEigen(in);
            case "eigen_lanczos":
                return computeEigenLanczos(in, threads, seed);
            case "eigen_qr":
                return computeEigenQR(in, threads);
            case "svd":
                return computeSvd(in);
            default:
                return null;
        }
    }

    /**
     * Executes a two-input operation; currently only {@code solve(A, b)}.
     *
     * @param in1    the coefficient matrix A (must be square)
     * @param in2    the right-hand side b
     * @param opcode must be {@code "solve"}
     * @return the solution, or {@code null} if the opcode is not supported
     * @throws DMLRuntimeException if A is not square
     */
    public static MatrixBlock matrixMatrixOperations(MatrixBlock in1, MatrixBlock in2, String opcode) {
        if (!opcode.equals("solve")) {
            return null;
        }
        if (in1.getNumRows() != in1.getNumColumns()) {
            throw new DMLRuntimeException("The A matrix, in solve(A,b) should have squared dimensions.");
        }
        return computeSolve(in1, in2);
    }

    /** Throws if {@code in} is not square, naming {@code operation} in the message. */
    private static void checkSquare(MatrixBlock in, String operation) {
        if (in.getNumRows() != in.getNumColumns()) {
            throw new DMLRuntimeException(operation
                + " can only be done on a square matrix. Input matrix is rectangular (rows="
                + in.getNumRows() + ", cols=" + in.getNumColumns() + ")");
        }
    }

    /** Solves A x = b through Commons Math's QR-based solver. */
    private static MatrixBlock computeSolve(MatrixBlock in1, MatrixBlock in2) {
        BlockRealMatrix a = DataConverter.convertToBlockRealMatrix(in1);
        BlockRealMatrix b = DataConverter.convertToBlockRealMatrix(in2);
        return DataConverter.convertToMatrixBlock(new QRDecomposition(a).getSolver().solve(b));
    }

    /** Commons Math QR; returns {H (Householder vectors), R}. */
    private static MatrixBlock[] computeQR(MatrixBlock in) {
        QRDecomposition qr = new QRDecomposition(DataConverter.convertToArray2DRowRealMatrix(in));
        MatrixBlock h = DataConverter.convertToMatrixBlock(qr.getH().getData());
        MatrixBlock r = DataConverter.convertToMatrixBlock(qr.getR().getData());
        return new MatrixBlock[] {h, r};
    }

    /** Commons Math LU of a square matrix; returns {P, L, U}. */
    private static MatrixBlock[] computeLU(MatrixBlock in) {
        checkSquare(in, "LU Decomposition");
        LUDecomposition lu = new LUDecomposition(DataConverter.convertToArray2DRowRealMatrix(in));
        return new MatrixBlock[] {
            DataConverter.convertToMatrixBlock(lu.getP().getData()),
            DataConverter.convertToMatrixBlock(lu.getL().getData()),
            DataConverter.convertToMatrixBlock(lu.getU().getData())};
    }

    /**
     * Commons Math eigen decomposition of a square matrix, falling back to a slightly perturbed
     * input if Commons Math exceeds its iteration limit. Only real parts of eigenvalues are returned.
     */
    private static MatrixBlock[] computeEigen(MatrixBlock in) {
        checkSquare(in, "Eigen Decomposition");
        EigenDecomposition eigen;
        try {
            eigen = new EigenDecomposition(DataConverter.convertToArray2DRowRealMatrix(in));
        } catch (MaxCountExceededException e) {
            LOG.warn("Eigen: " + e.getMessage() + ". Falling back to regularized eigen factorization.");
            eigen = computeEigenRegularized(in);
        }
        return sortEVs(eigen.getRealEigenvalues(), eigen.getV().getData());
    }

    /**
     * Eigen decomposition of a copy of {@code in} in which every non-zero entry is pushed
     * {@link #EIGEN_LAMBDA} further away from zero (zeros stay zero since {@code signum(0) == 0}).
     *
     * @throws DMLRuntimeException if {@code in} is null or empty
     */
    private static EigenDecomposition computeEigenRegularized(MatrixBlock in) {
        if (in == null || in.isEmptyBlock(false)) {
            throw new DMLRuntimeException("Invalid empty block");
        }
        // NOTE(review): assumes MatrixBlock(in, false) yields a dense copy so getDenseBlock() is
        // non-null even for sparse inputs -- confirm.
        MatrixBlock copy = new MatrixBlock(in, false);
        DenseBlock dense = copy.getDenseBlock();
        for (int r = 0; r < copy.getNumRows(); r++) {
            double[] vals = dense.values(r);
            int off = dense.pos(r);
            for (int c = 0; c < copy.getNumColumns(); c++) {
                vals[off + c] += Math.signum(vals[off + c]) * EIGEN_LAMBDA;
            }
        }
        return new EigenDecomposition(DataConverter.convertToArray2DRowRealMatrix(copy));
    }

    /** Commons Math SVD; returns {U, diag(singular values), V}. */
    private static MatrixBlock[] computeSvd(MatrixBlock in) {
        SingularValueDecomposition svd = new SingularValueDecomposition(DataConverter.convertToArray2DRowRealMatrix(in));
        RealMatrix u = svd.getU();
        RealMatrix v = svd.getV();
        MatrixBlock sigma = DataConverter.convertToMatrixBlock(svd.getSingularValues(), true);
        int k = sigma.getNumRows();
        MatrixBlock sigmaDiag = LibMatrixReorg.diag(sigma, new MatrixBlock(k, k, true));
        return new MatrixBlock[] {
            DataConverter.convertToMatrixBlock(u.getData()),
            sigmaDiag,
            DataConverter.convertToMatrixBlock(v.getData())};
    }

    /** Inverse of a square matrix via Commons Math's QR solver. */
    private static MatrixBlock computeMatrixInverse(Array2DRowRealMatrix matrix) {
        if (!matrix.isSquare()) {
            throw new DMLRuntimeException("Input to inv() must be square matrix -- given: a "
                + matrix.getRowDimension() + "x" + matrix.getColumnDimension() + " matrix.");
        }
        return DataConverter.convertToMatrixBlock(new QRDecomposition(matrix).getSolver().getInverse().getData());
    }

    /** Lower-triangular Cholesky factor L of a square matrix via Commons Math. */
    private static MatrixBlock computeCholesky(Array2DRowRealMatrix matrix) {
        if (!matrix.isSquare()) {
            throw new DMLRuntimeException("Input to cholesky() must be square matrix -- given: a "
                + matrix.getRowDimension() + "x" + matrix.getColumnDimension() + " matrix.");
        }
        CholeskyDecomposition chol =
            new CholeskyDecomposition(matrix, RELATIVE_SYMMETRY_THRESHOLD, ABSOLUTE_POSITIVITY_THRESHOLD);
        return DataConverter.convertToMatrixBlock(chol.getL().getData());
    }

    /**
     * Random non-negative column vector of unit Euclidean norm: uniform [0,1] draws r_i are mapped
     * to sqrt(r_i / sum(r)), so the squares sum to one.
     *
     * @throws DMLRuntimeException if the squared norm deviates from 1 by {@link #NORMALIZATION_TOLERANCE} or more
     */
    private static MatrixBlock randNormalizedVect(int dim, int threads, long seed) {
        MatrixBlock rand = MatrixBlock.randOperations(dim, 1, 1.0d, 0.0d, 1.0d, "UNIFORM", seed);
        ScalarOperator divBySum = new RightScalarOperator(Divide.getDivideFnObject(), rand.sum(), threads);
        UnaryOperator sqrt = new UnaryOperator(Builtin.getBuiltinFnObject(Builtin.BuiltinCode.SQRT), threads, true);
        MatrixBlock v = rand.scalarOperations(divBySum, (MatrixValue) new MatrixBlock())
            .unaryOperations(sqrt, (MatrixValue) new MatrixBlock());
        if (Math.abs(v.sumSq() - 1.0d) >= NORMALIZATION_TOLERANCE) {
            throw new DMLRuntimeException("v1 not correctly normalized (maybe try changing the seed)");
        }
        return v;
    }

    /**
     * Lanczos tridiagonalization (n steps, no reorthogonalization) followed by {@code eigen} of the
     * tridiagonal matrix; eigenvectors are mapped back through the Lanczos basis.
     *
     * <p>NOTE(review): there is no guard for beta == 0 (Lanczos breakdown, e.g. for an identity
     * input); the division by beta then produces non-finite values.
     */
    private static MatrixBlock[] computeEigenLanczos(MatrixBlock in, int threads, long seed) {
        checkSquare(in, "Lanczos algorithm and Eigen Decomposition");
        int n = in.getNumRows();
        MatrixBlock vPrev = new MatrixBlock(n, 1, 0.0d);
        MatrixBlock v = randNormalizedVect(n, threads, seed);
        MatrixBlock tri = new MatrixBlock(n, n, 0.0d);   // tridiagonal T
        MatrixBlock basis = new MatrixBlock(n, n, 0.0d); // Lanczos vectors as columns
        ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), threads);
        TernaryOperator minusMult = new TernaryOperator(MinusMultiply.getFnObject(), threads);
        AggregateBinaryOperator matMult = InstructionUtils.getMatMultOperator(threads);
        MatrixBlock beta = new MatrixBlock(1, 1, 0.0d);

        for (int k = 0; k < n; k++) {
            v.putInto(basis, 0, k, false);
            MatrixBlock w = in.aggregateBinaryOperations(in, v, matMult);
            MatrixBlock vt = v.reorgOperations(transpose, (MatrixValue) new MatrixBlock(), 0, 0, n);
            MatrixBlock alpha = w.aggregateBinaryOperations(vt, w, matMult);
            if (k < n - 1) {
                // residual = A v - alpha v - beta v_prev
                MatrixBlock residual = w.ternaryOperations(minusMult, v, alpha, new MatrixBlock())
                    .ternaryOperations(minusMult, vPrev, beta, new MatrixBlock());
                double betaValue = Math.sqrt(residual.sumSq());
                beta.setValue(0, 0, betaValue);
                vPrev.copy(v);
                ScalarOperator divByBeta = new RightScalarOperator(Divide.getDivideFnObject(), betaValue, threads);
                residual.scalarOperations(divByBeta, (MatrixValue) v); // writes the next Lanczos vector into v
                tri.setValue(k + 1, k, betaValue);
                tri.setValue(k, k + 1, betaValue);
            }
            tri.setValue(k, k, alpha.getValue(0, 0));
        }

        MatrixBlock[] eigen = computeEigen(tri);
        // putInto(..., false) is not relied on to maintain the nnz count; set the dense upper bound
        // explicitly. Long arithmetic avoids int overflow for n > 46340.
        basis.setNonZeros((long) n * n);
        eigen[1] = basis.aggregateBinaryOperations(basis, eigen[1], matMult);
        return eigen;
    }

    /**
     * Householder QR of a square matrix: returns {Q, R} with Q orthogonal and R upper triangular.
     * Columns whose remaining part is already zero are skipped (no reflection needed); the sign of
     * the shift uses {@link Math#copySign} so a zero pivot still yields a valid reflection.
     */
    private static MatrixBlock[] computeQR2(MatrixBlock in, int threads) {
        checkSquare(in, "QR2 Decomposition");
        int n = in.getNumRows();
        MatrixBlock r = new MatrixBlock();
        r.copy(in);
        MatrixBlock q = new MatrixBlock(n, n, true);
        for (int i = 0; i < n; i++) {
            q.setValue(i, i, 1.0d);
        }
        ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), threads);
        AggregateBinaryOperator matMult = InstructionUtils.getMatMultOperator(threads);
        BinaryOperator minus = InstructionUtils.parseExtendedBinaryOperator("-");
        ScalarOperator timesTwo = new LeftScalarOperator(Multiply.getMultiplyFnObject(), 2.0d, threads);

        for (int k = 0; k < n; k++) {
            MatrixBlock x = r.slice(k, n - 1, k, k);
            double xNorm = Math.sqrt(x.sumSq());
            if (xNorm == 0.0d) {
                continue; // nothing to eliminate; a reflection would divide 0 by 0
            }
            // u = x + sign(x0) * ||x|| * e1, with sign(+-0) treated as +-1 so u never equals x
            MatrixBlock u = new MatrixBlock(n - k, 1, 0.0d);
            u.copy(x);
            u.setValue(0, 0, u.getValue(0, 0) + Math.copySign(xNorm, x.getValue(0, 0)));
            ScalarOperator divByNorm = new RightScalarOperator(Divide.getDivideFnObject(), Math.sqrt(u.sumSq()), threads);
            MatrixBlock uUnit = u.scalarOperations(divByNorm, (MatrixValue) new MatrixBlock());

            // embed u into an n-vector and form H = 2 v v^T; the reflector is I - H
            MatrixBlock v = new MatrixBlock(n, 1, 0.0d);
            v.copy(k, n - 1, 0, 0, uUnit, true);
            MatrixBlock twoVt = v.reorgOperations(transpose, (MatrixValue) new MatrixBlock(), 0, 0, n)
                .scalarOperations(timesTwo, (MatrixValue) new MatrixBlock());
            MatrixBlock h = v.aggregateBinaryOperations(v, twoVt, matMult);

            r = r.binaryOperations(minus, r.aggregateBinaryOperations(h, r, matMult)); // R = (I - H) R
            q = q.binaryOperations(minus, q.aggregateBinaryOperations(q, h, matMult)); // Q = Q (I - H)
        }
        return new MatrixBlock[] {q, r};
    }

    /** {@code eigen_qr} with the default iteration cap and tolerance. */
    private static MatrixBlock[] computeEigenQR(MatrixBlock in, int threads) {
        return computeEigenQR(in, EIGEN_QR_DEFAULT_ITERATIONS, EIGEN_QR_DEFAULT_TOLERANCE, threads);
    }

    /**
     * Unshifted QR-iteration eigen decomposition: repeatedly A := R Q with A = Q R, accumulating Q.
     * Iteration stops after {@code numIterations} steps or once the Frobenius norm of the strictly
     * lower triangle of A is at most {@code tol}. Eigenvalues are read from the diagonal of A.
     *
     * @return {eigenvalues (column vector, ascending), accumulated Q (columns reordered accordingly)}
     */
    private static MatrixBlock[] computeEigenQR(MatrixBlock in, int numIterations, double tol, int threads) {
        checkSquare(in, "Eigen Decomposition (QR)");
        int n = in.getNumRows();
        AggregateBinaryOperator matMult = InstructionUtils.getMatMultOperator(threads);
        MatrixBlock qProd = new MatrixBlock(n, n, 0.0d);
        for (int i = 0; i < n; i++) {
            qProd.setValue(i, i, 1.0d);
        }
        MatrixBlock a = in;
        for (int iter = 0; iter < numIterations; iter++) {
            if (strictlyLowerNorm(a) <= tol) {
                break; // A is (numerically) upper triangular
            }
            MatrixBlock[] qr = computeQR2(a, threads);
            qProd = qProd.aggregateBinaryOperations(qProd, qr[0], matMult);
            a = qr[1].aggregateBinaryOperations(qr[1], qr[0], matMult);
        }
        // read through getDouble so sparse or empty results (null dense arrays) are handled
        double[] evals = new double[n];
        double[][] evecs = new double[n][n];
        for (int i = 0; i < n; i++) {
            evals[i] = a.getDouble(i, i);
            for (int j = 0; j < n; j++) {
                evecs[i][j] = qProd.getDouble(i, j);
            }
        }
        return sortEVs(evals, evecs);
    }

    /** Frobenius norm of the strictly lower triangular part of a square matrix. */
    private static double strictlyLowerNorm(MatrixBlock a) {
        double sumSq = 0.0d;
        for (int i = 1; i < a.getNumRows(); i++) {
            for (int j = 0; j < i; j++) {
                double val = a.getDouble(i, j);
                sumSq += val * val;
            }
        }
        return Math.sqrt(sumSq);
    }

    /**
     * Householder reduction of a square matrix to tridiagonal form (A := P A P per column).
     * Columns already reduced below the subdiagonal are skipped to avoid a 0/0 normalization.
     *
     * <p>NOTE(review): this method is not referenced anywhere in this class.
     */
    private static MatrixBlock computeHouseholder(MatrixBlock in, int threads) {
        int n = in.getNumRows();
        MatrixBlock a = new MatrixBlock(n, n, 0.0d);
        a.copy(in);
        ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), threads);
        AggregateBinaryOperator matMult = InstructionUtils.getMatMultOperator(threads);
        BinaryOperator plus = InstructionUtils.parseExtendedBinaryOperator("+");
        BinaryOperator minus = InstructionUtils.parseExtendedBinaryOperator("-");

        for (int k = 0; k < n - 2; k++) {
            MatrixBlock x = a.slice(0, n - 1, k, k);
            for (int i = 0; i <= k; i++) {
                x.setValue(i, 0, 0.0d);
            }
            double alpha = Math.sqrt(x.sumSq());
            double sub = a.getDouble(k + 1, k);
            if (sub > 0.0d) {
                alpha = -alpha;
            }
            double r = Math.sqrt(0.5d * ((alpha * alpha) - (sub * alpha)));
            if (r == 0.0d) {
                continue; // column already tridiagonal; normalizing by 2r would divide by zero
            }
            MatrixBlock v = new MatrixBlock(n, 1, 0.0d);
            v.copy(x);
            v.setValue(k + 1, 0, sub - alpha);
            v = v.scalarOperations((ScalarOperator) new RightScalarOperator(Divide.getDivideFnObject(), 2.0d * r, threads),
                (MatrixValue) new MatrixBlock());

            MatrixBlock eye = new MatrixBlock(n, n, 0.0d);
            for (int i = 0; i < n; i++) {
                eye.setValue(i, i, 1.0d);
            }
            MatrixBlock vt = v.reorgOperations(transpose, (MatrixValue) new MatrixBlock(), 0, 0, n);
            MatrixBlock p = eye.binaryOperations(minus, a.aggregateBinaryOperations(v, vt.binaryOperations(plus, vt), matMult));
            a = a.aggregateBinaryOperations(p, a.aggregateBinaryOperations(a, p, matMult), matMult);
        }
        return a;
    }

    /**
     * Sorts eigenvalues ascending in place (selection sort, stable for ties) and swaps the
     * corresponding eigenvector columns of {@code evecs} along with them.
     *
     * @return {eigenvalues as column vector, eigenvectors}
     */
    private static MatrixBlock[] sortEVs(double[] evals, double[][] evecs) {
        int n = evals.length;
        for (int i = 0; i < n; i++) {
            int minPos = i;
            for (int j = i + 1; j < n; j++) {
                if (evals[j] < evals[minPos]) {
                    minPos = j;
                }
            }
            if (minPos != i) {
                double tmp = evals[i];
                evals[i] = evals[minPos];
                evals[minPos] = tmp;
                for (int row = 0; row < n; row++) {
                    double cell = evecs[row][i];
                    evecs[row][i] = evecs[row][minPos];
                    evecs[row][minPos] = cell;
                }
            }
        }
        return new MatrixBlock[] {DataConverter.convertToMatrixBlock(evals, true), DataConverter.convertToMatrixBlock(evecs)};
    }
}
