package com.zavtech.morpheus.reference.regress;

import com.zavtech.morpheus.array.Array;
import com.zavtech.morpheus.frame.DataFrame;
import com.zavtech.morpheus.frame.DataFrameException;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/* loaded from: input_file:com/zavtech/morpheus/reference/regress/XDataFrame_WLS.class */
/**
 * Weighted least squares (WLS) flavour of {@link XDataFrameLeastSquares}.
 *
 * <p>The model is fitted by pre-multiplying both the response vector and the design matrix
 * by a diagonal matrix holding the square roots of the observation weights, and then handing
 * the transformed inputs to the two-argument {@code compute} defined outside this class.</p>
 *
 * @param <R> first type argument of the backing {@link DataFrame}
 * @param <C> column key type; used for the regressand and the regressors
 */
class XDataFrame_WLS<R, C> extends XDataFrameLeastSquares<R, C> {
    /** Observation weights, looked up by the position of each regressand value. */
    private Array<Double> weights;

    /**
     * Creates a WLS model, registering it with the superclass under the name {@code "WLS"}.
     *
     * @param dataFrame frame forwarded to the superclass constructor
     * @param c         column key forwarded to the superclass (presumably the regressand - confirm against the superclass)
     * @param list      column keys forwarded to the superclass (presumably the regressors - confirm against the superclass)
     * @param z         flag forwarded to the superclass (presumably whether an intercept is fitted - confirm against the superclass)
     * @param array     observation weights retained by this instance
     */
    public XDataFrame_WLS(DataFrame<R, C> dataFrame, C c, List<C> list, boolean z, Array<Double> array) {
        super("WLS", dataFrame, c, list, z);
        this.weights = array;
    }

    /**
     * Runs the regression on the weight-transformed response and design matrix.
     *
     * @throws DataFrameException re-thrown untouched when raised by the computation itself, or
     *                            wrapping any other exception together with the regressand and regressors
     */
    @Override
    public void compute() {
        try {
            final RealVector response = createY();
            final RealMatrix design = createX();
            // sqrt(w) on the diagonal: scaling y and X by it yields the weighted problem.
            final double[] rootWeights = this.weights.stream().doubles().map(Math::sqrt).toArray();
            final DiagonalMatrix scaling = new DiagonalMatrix(rootWeights);
            final RealVector scaledResponse = scaling.operate(response);
            final RealMatrix scaledDesign = scaling.multiply(design);
            compute(scaledResponse, scaledDesign);
        } catch (DataFrameException frameFailure) {
            // Already the exception type callers expect; do not wrap it a second time.
            throw frameFailure;
        } catch (Exception cause) {
            final C target = getRegressand();
            final String predictors = Arrays.toString(getRegressors().toArray());
            final String message = String.format("WLS regression failed for %s on %s", target, predictors);
            throw new DataFrameException(message, cause);
        }
    }

    /**
     * Computes the total sum of squares for this model.
     *
     * <p>When {@code hasIntercept()} is true, the result is the weighted sum of squared deviations
     * of the regressand column values from their weighted mean. Otherwise it is simply the dot
     * product of {@code realVector} with itself.</p>
     *
     * @param realVector the response vector; only consulted when {@code hasIntercept()} is false
     * @return the total sum of squares
     */
    @Override
    protected double computeTSS(RealVector realVector) {
        if (hasIntercept()) {
            final C target = getRegressand();
            final double weightTotal = this.weights.stats().sum().doubleValue();
            final Array<Double> observed = Array.of(frame().col(target).toDoubleStream().toArray());
            final double weightedSum = observed
                .mapToDoubles(value -> value.getDouble() * this.weights.getDouble(value.index()))
                .stats()
                .sum()
                .doubleValue();
            final double weightedMean = weightedSum / weightTotal;
            return observed
                .mapToDoubles(value -> {
                    final double weight = this.weights.getDouble(value.index());
                    final double deviation = value.getDouble() - weightedMean;
                    return weight * Math.pow(deviation, 2.0d);
                })
                .stats()
                .sum()
                .doubleValue();
        }
        return realVector.dotProduct(realVector);
    }
}
