package org.apache.sysds.runtime.compress.colgroup.functional;

import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.readers.ReaderColumnSelection;
import org.apache.sysds.runtime.compress.utils.DblArray;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/functional/LinearRegression.class */
public interface LinearRegression {

    /**
     * Fits, independently for every selected column, the least-squares line
     * {@code y = beta0 + beta1 * x}, where {@code x} is the 1-based position of a data point
     * (1, 2, ..., n) and {@code y} is the cell value at that position.
     *
     * <p>Because the regressor is always 1..n, the closed form simplifies to:
     * <ul>
     *   <li>{@code xBar = (n + 1) / 2}</li>
     *   <li>{@code Sxx  = sum_x (x - xBar)^2 = (n^3 - n) / 12}</li>
     *   <li>{@code beta1 = (sum x*y - xBar * sum y) / Sxx}</li>
     *   <li>{@code beta0 = sum y / n - beta1 * xBar}</li>
     * </ul>
     *
     * @param rawBlock   the matrix holding the data
     * @param colIndexes the indexes to fit; when {@code transposed} is true these index rows of
     *                   {@code rawBlock}, otherwise columns
     * @param transposed if true, data points run along the columns of {@code rawBlock}
     *                   ({@code n = rawBlock.getNumColumns()}); otherwise along its rows
     *                   ({@code n = rawBlock.getNumRows()})
     * @return an array of length {@code 2 * colIndexes.size()}: entries {@code [0, k)} are the
     *         intercepts and entries {@code [k, 2k)} the slopes, both in the order of
     *         {@code colIndexes} (with {@code k = colIndexes.size()})
     * @throws DMLCompressionException if there are fewer than 2 data points or no columns
     */
    static double[] regressMatrixBlock(MatrixBlock rawBlock, IColIndex colIndexes, boolean transposed) {
        // Number of data points per fitted series (not necessarily the matrix's column count).
        final int nRows = transposed ? rawBlock.getNumColumns() : rawBlock.getNumRows();
        final int nCols = colIndexes.size();

        if (nRows <= 1) {
            throw new DMLCompressionException("At least 2 data points are required to fit a linear function.");
        }
        if (nCols < 1) {
            throw new DMLCompressionException("At least 1 column must be specified for compression.");
        }

        // First nCols entries: intercepts (beta0); next nCols entries: slopes (beta1).
        final double[] beta0Beta1 = new double[2 * nCols];

        // Centred sum of squares of x = 1..n, and the mean of x.
        final double sXX = (Math.pow(nRows, 3.0d) - nRows) / 12.0d;
        final double xBar = (nRows + 1) / 2.0d;

        final double[] colSums = new double[nCols];          // sum_x y
        final double[] weightedColSums = new double[nCols];  // sum_x x * y

        if (nCols == 1) {
            // Single column: read cells directly instead of going through a column-selection reader.
            final int col = colIndexes.get(0); // loop invariant, resolve once
            double sum = 0.0d;
            double weightedSum = 0.0d;
            for (int r = 0; r < nRows; r++) {
                final double value = transposed ? rawBlock.getValue(col, r) : rawBlock.getValue(r, col);
                sum += value;
                weightedSum += (r + 1) * value;
            }
            colSums[0] = sum;
            weightedColSums[0] = weightedSum;
        } else {
            final ReaderColumnSelection reader = ReaderColumnSelection.createReader(rawBlock, colIndexes, transposed);
            DblArray cells;
            while ((cells = reader.nextRow()) != null) {
                // Position comes from the reader rather than a local counter, so the sums do not
                // rely on the reader returning every row.
                final int x = reader.getCurrentRowIndex() + 1;
                final double[] row = cells.getData();
                for (int i = 0; i < nCols; i++) {
                    colSums[i] += row[i];
                    weightedColSums[i] += x * row[i];
                }
            }
        }

        for (int i = 0; i < nCols; i++) {
            final double slope = ((-xBar) * colSums[i] + weightedColSums[i]) / sXX;
            beta0Beta1[nCols + i] = slope;
            beta0Beta1[i] = (colSums[i] / nRows) - (slope * xBar);
        }
        return beta0Beta1;
    }
}
