package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.class */
/**
 * Dense dictionary-encoded column group with a per-column frame-of-reference offset.
 *
 * <p>Each row {@code r} is mapped through {@link #_data} to a dictionary tuple. The logical value at
 * {@code (r, c)} is {@code dict[_data[r], c] + _reference[c]} (see {@link #getIdx(int, int)}).
 */
public class ColGroupDDCFOR extends AMorphingMMColGroup implements IFrameOfReferenceGroup {
    private static final long serialVersionUID = -5769772089913918987L;

    /** Mapping from each row to its dictionary entry. */
    protected final AMapToData _data;

    /** Offset added to every dictionary value, one entry per column of {@code _colIndexes}. */
    protected final double[] _reference;

    private ColGroupDDCFOR(IColIndex colIndexes, IDictionary dict, double[] reference, AMapToData data,
        int[] cachedCounts) {
        super(colIndexes, dict, cachedCounts);
        _data = data;
        _reference = reference;
    }

    /**
     * Creates the most specific column group for the given parts.
     *
     * <ul>
     * <li>No dictionary and an all-zero reference: an empty group.</li>
     * <li>No dictionary: a constant group holding the reference.</li>
     * <li>A mapping with a single unique value: a constant group of {@code dict + reference}.</li>
     * <li>An all-zero reference: a plain DDC group.</li>
     * <li>Otherwise: a new DDCFOR group.</li>
     * </ul>
     *
     * @param colIndexes   the column indexes covered by the group
     * @param dict         the dictionary (may be {@code null})
     * @param data         the row-to-dictionary mapping
     * @param cachedCounts cached tuple counts, or {@code null}
     * @param reference    the per-column reference offsets
     * @return the created column group
     */
    public static AColGroup create(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts,
        double[] reference) {
        final boolean allZero = ColGroupUtils.allZero(reference);
        if (dict == null) {
            return allZero ? new ColGroupEmpty(colIndexes) : ColGroupConst.create(colIndexes, reference);
        }
        if (data.getUnique() == 1) {
            return ColGroupConst.create(colIndexes,
                dict.binOpRight(new BinaryOperator(Plus.getPlusFnObject()), reference));
        }
        if (allZero) {
            return ColGroupDDC.create(colIndexes, dict, data, cachedCounts);
        }
        return new ColGroupDDCFOR(colIndexes, dict, reference, data, cachedCounts);
    }

    /**
     * Converts a DDC group into a DDCFOR group whose reference is the most common value per column.
     *
     * @param g the DDC group to convert
     * @return the converted group, or {@code g} itself if no common values were extracted
     * @throws NotImplementedException if the dictionary cannot be materialized as a matrix block
     */
    public static AColGroup sparsifyFOR(ColGroupDDC g) {
        final MatrixBlockDictionary mbDict = g._dict.getMBDict(g.getColIndices().size());
        if (mbDict == null) {
            throw new NotImplementedException("The dictionary was empty... highly unlikely");
        }
        final double[] mostCommon = ColGroupUtils.extractMostCommonValueInColumns(mbDict.getMatrixBlock());
        if (mostCommon == null) {
            return g;
        }
        // Subtract the new reference from every dictionary tuple so dict + reference is unchanged.
        return create(g.getColIndices(), mbDict.binOpRight(new BinaryOperator(Minus.getMinusFnObject()), mostCommon),
            g._data, g.getCachedCounts(), mostCommon);
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.DDCFOR;
    }

    @Override
    public double getIdx(int r, int colIdx) {
        return _dict.getValue(_data.getIndex(r), colIdx, _colIndexes.size()) + _reference[colIdx];
    }

    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        for (int r = rl; r < ru; r++) {
            c[r] += preAgg[_data.getIndex(r)];
        }
    }

    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        for (int r = rl; r < ru; r++) {
            c[r] = builtin.execute(c[r], preAgg[_data.getIndex(r)]);
        }
    }

    @Override
    public int[] getCounts(int[] counts) {
        return _data.getCounts(counts);
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.DDCFOR;
    }

    /** Adds the mapping size and 8 bytes per reference entry to the base estimate. */
    @Override
    public long estimateInMemorySize() {
        return super.estimateInMemorySize() + _data.getInMemorySize() + 8L * _colIndexes.size();
    }

    /**
     * Applies a scalar operator to every cell.
     *
     * <p>For Plus/Minus only the reference is shifted. For Multiply/Divide the dictionary and the
     * reference are both transformed. Any other operator goes through the reference-aware dictionary path.
     *
     * <p>NOTE(review): the Plus/Minus and Multiply/Divide shortcuts are only exact when the scalar is the
     * right operand ({@code x op s}). If a left-scalar operator ({@code s - x}, {@code s / x}) can reach
     * this method, those cases need the reference-aware path. Confirm against the callers.
     */
    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        final double[] newRef = new double[_reference.length];
        for (int c = 0; c < _reference.length; c++) {
            newRef[c] = op.executeScalar(_reference[c]);
        }
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            return create(_colIndexes, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            return create(_colIndexes, _dict.applyScalarOp(op), _data, getCachedCounts(), newRef);
        }
        return create(_colIndexes, _dict.applyScalarOpWithReference(op, _reference, newRef), _data,
            getCachedCounts(), newRef);
    }

    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        final double[] newRef = ColGroupUtils.unaryOperator(op, _reference);
        return create(_colIndexes, _dict.applyUnaryOpWithReference(op, _reference, newRef), _data,
            getCachedCounts(), newRef);
    }

    /**
     * Computes {@code v op x} row-wise, with the vector {@code v} on the left.
     *
     * <p>Only Plus and Multiply distribute over the reference when {@code v} is the left operand:
     * <ul>
     * <li>{@code v + (d + r) == d + (v + r)}</li>
     * <li>{@code v * (d + r) == v*d + v*r}</li>
     * </ul>
     * Minus ({@code v - (d + r) != d + (v - r)}) and Divide ({@code v / (d + r) != v/d + v/r}) must use the
     * reference-aware dictionary path.
     */
    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newRef = new double[_reference.length];
        for (int c = 0; c < _reference.length; c++) {
            newRef[c] = op.fn.execute(v[_colIndexes.get(c)], _reference[c]);
        }
        if (op.fn instanceof Plus) {
            return create(_colIndexes, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply) {
            return create(_colIndexes, _dict.binOpLeft(op, v, _colIndexes), _data, getCachedCounts(), newRef);
        }
        return create(_colIndexes, _dict.binOpLeftWithReference(op, v, _colIndexes, _reference, newRef), _data,
            getCachedCounts(), newRef);
    }

    /**
     * Computes {@code x op v} row-wise, with the vector {@code v} on the right.
     *
     * <p>Plus/Minus only shift the reference. Multiply/Divide scale the dictionary and the reference.
     * Other operators use the reference-aware path.
     */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newRef = new double[_reference.length];
        for (int c = 0; c < _reference.length; c++) {
            newRef[c] = op.fn.execute(_reference[c], v[_colIndexes.get(c)]);
        }
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            return create(_colIndexes, _dict, _data, getCachedCounts(), newRef);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            return create(_colIndexes, _dict.binOpRight(op, v, _colIndexes), _data, getCachedCounts(), newRef);
        }
        return create(_colIndexes, _dict.binOpRightWithReference(op, v, _colIndexes, _reference, newRef), _data,
            getCachedCounts(), newRef);
    }

    /** Writes the base state, then the mapping, then one double per reference entry. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _data.write(out);
        for (double d : _reference) {
            out.writeDouble(d);
        }
    }

    /**
     * Reads a group written by {@link #write(DataOutput)}.
     *
     * <p>The mapping must be read before the reference array, mirroring the write order. Each read is
     * bound to a local so the stream order is explicit and not tied to argument evaluation order.
     *
     * @param in the input to read from
     * @return the deserialized group
     * @throws IOException if reading fails
     */
    public static ColGroupDDCFOR read(DataInput in) throws IOException {
        final IColIndex cols = ColIndexFactory.read(in);
        final IDictionary dict = DictionaryFactory.read(in);
        final AMapToData data = MapToFactory.readIn(in);
        final double[] reference = ColGroupIO.readDoubleArray(cols.size(), in);
        return new ColGroupDDCFOR(cols, dict, reference, data, null);
    }

    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + _data.getExactSizeOnDisk() + 8L * _colIndexes.size();
    }

    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        return e.getCost(nRows, nRows, getNumCols(), getNumValues(), _dict.getSparsity());
    }

    /**
     * Replaces {@code pattern} with {@code replace} in both the dictionary and the reference.
     *
     * <p>The same {@link Util#eq} comparison is used to detect and to replace reference entries. A match
     * that {@code ==} would miss (for example NaN, if Util.eq is NaN-aware — confirm) is therefore still
     * replaced. The reference array is copied on the first match so the shared array is never mutated.
     */
    @Override
    public AColGroup replace(double pattern, double replace) {
        final IDictionary newDict = _dict.replaceWithReference(pattern, replace, _reference);
        double[] newRef = _reference;
        for (int c = 0; c < _reference.length; c++) {
            if (Util.eq(pattern, _reference[c])) {
                if (newRef == _reference) {
                    newRef = Arrays.copyOf(_reference, _reference.length);
                }
                newRef[c] = replace;
            }
        }
        return create(_colIndexes, newDict, _data, getCachedCounts(), newRef);
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return _dict.aggregateWithReference(c, builtin, _reference, false);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        _dict.aggregateColsWithReference(c, builtin, _colIndexes, _reference, false);
    }

    /** Adds the dictionary sum, plus the reference sum once per row. */
    @Override
    public void computeSum(double[] c, int nRows) {
        super.computeSum(c, nRows);
        c[0] += ColGroupUtils.refSum(_reference) * nRows;
    }

    /** Adds the per-column dictionary sums, plus each column's reference once per row. */
    @Override
    public void computeColSums(double[] c, int nRows) {
        super.computeColSums(c, nRows);
        for (int j = 0; j < _colIndexes.size(); j++) {
            c[_colIndexes.get(j)] += _reference[j] * nRows;
        }
    }

    @Override
    public void computeSumSq(double[] c, int nRows) {
        c[0] += _dict.sumSqWithReference(getCounts(), _reference);
    }

    @Override
    public void computeColSumsSq(double[] c, int nRows) {
        _dict.colSumSqWithReference(c, getCounts(), _colIndexes, _reference);
    }

    @Override
    protected double[] preAggSumRows() {
        return _dict.sumAllRowsToDoubleWithReference(_reference);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return _dict.sumAllRowsToDoubleSqWithReference(_reference);
    }

    @Override
    protected double[] preAggProductRows() {
        return _dict.productAllRowsToDoubleWithReference(_reference);
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return _dict.aggregateRowsWithReference(builtin, _reference);
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        _dict.productWithReference(c, getCounts(), _reference, 0);
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        for (int r = rl; r < ru; r++) {
            c[r] *= preAgg[_data.getIndex(r)];
        }
    }

    @Override
    public void computeColProduct(double[] c, int nRows) {
        _dict.colProductWithReference(c, getCounts(), _colIndexes, _reference);
    }

    /** Slices columns {@code [cl, cu)} out of the dictionary and the reference. */
    @Override
    public AColGroup sliceMultiColumns(int cl, int cu, IColIndex outputCols) {
        final IDictionary sliced = _dict.sliceOutColumnRange(cl, cu, _colIndexes.size());
        final double[] slicedRef = Arrays.copyOfRange(_reference, cl, cu);
        return create(outputCols, sliced, _data, getCounts(), slicedRef);
    }

    @Override
    public AColGroup sliceSingleColumn(int idx) {
        final IColIndex outCols = ColIndexFactory.create(1);
        if (_colIndexes.size() == 1) {
            return create(outCols, _dict, _data, getCounts(), _reference);
        }
        final IDictionary sliced = _dict.sliceOutColumnRange(idx, idx + 1, _colIndexes.size());
        return create(outCols, sliced, _data, getCounts(), new double[] {_reference[idx]});
    }

    @Override
    public boolean containsValue(double pattern) {
        if (Double.isNaN(pattern) || Double.isInfinite(pattern)) {
            return ColGroupUtils.containsInfOrNan(pattern, _reference) || _dict.containsValue(pattern);
        }
        return _dict.containsValueWithReference(pattern, _reference);
    }

    /** Upper bound: counts every cell of the group as non-zero. */
    @Override
    public long getNumberNonZeros(int nRows) {
        return _colIndexes.size() * nRows;
    }

    /** Adds the reference into {@code constV} and returns the remaining offset-free DDC group. */
    @Override
    public AColGroup extractCommon(double[] constV) {
        for (int j = 0; j < _colIndexes.size(); j++) {
            constV[_colIndexes.get(j)] += _reference[j];
        }
        return ColGroupDDC.create(_colIndexes, _dict, _data, getCounts());
    }

    /**
     * Expands the single (first) reference column into {@code max} indicator columns.
     *
     * @throws DMLRuntimeException if the reference value is non-positive and {@code ignore} is false
     */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final int refVal = (int) _reference[0];
        final IDictionary expanded = _dict.rexpandColsWithReference(max, ignore, cast, refVal);
        if (expanded == null) {
            if (refVal <= 0 || refVal > max) {
                return ColGroupEmpty.create(max);
            }
            final double[] constRow = new double[max];
            constRow[refVal - 1] = 1.0;
            return ColGroupConst.create(constRow);
        }
        final IColIndex outCols = ColIndexFactory.create(max);
        if (refVal <= 0) {
            if (!ignore) {
                throw new DMLRuntimeException("Invalid content of zero in rexpand");
            }
            return ColGroupDDC.create(outCols, expanded, _data, getCachedCounts());
        }
        if (refVal > max) {
            return ColGroupDDC.create(outCols, expanded, _data, getCachedCounts());
        }
        final double[] oneHot = new double[max];
        oneHot[refVal - 1] = 1.0;
        return create(outCols, expanded, _data, getCachedCounts(), oneHot);
    }

    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        return _dict.centralMomentWithReference(op.fn, getCounts(), _reference[0], nRows);
    }

    @Override
    public double[] getCommon() {
        return _reference;
    }

    @Override
    protected AColGroup allocateRightMultiplicationCommon(double[] common, IColIndex colIndexes, IDictionary preAgg) {
        return create(colIndexes, preAgg, _data, getCachedCounts(), common);
    }

    @Override
    public AColGroup sliceRows(int rl, int ru) {
        return new ColGroupDDCFOR(_colIndexes, _dict, _reference, _data.slice(rl, ru), null);
    }

    @Override
    public AColGroup copyAndSet(IColIndex colIndexes, IDictionary newDictionary) {
        return create(colIndexes, newDictionary, _data, getCachedCounts(), _reference);
    }

    /**
     * Appends another DDCFOR group's rows.
     *
     * @return the combined group, or {@code null} if {@code g} is not a DDCFOR group with identical
     *         columns, reference and dictionary
     */
    @Override
    public AColGroup append(AColGroup g) {
        if (!(g instanceof ColGroupDDCFOR) || !g.getColIndices().equals(_colIndexes)) {
            return null;
        }
        final ColGroupDDCFOR other = (ColGroupDDCFOR) g;
        if (!Arrays.equals(_reference, other._reference) || !other._dict.equals(_dict)) {
            return null;
        }
        return create(_colIndexes, _dict, _data.append(other._data), null, _reference);
    }

    @Override
    public AColGroup appendNInternal(AColGroup[] groups, int blen, int rlen) {
        throw new NotImplementedException();
    }

    @Override
    public ICLAScheme getCompressionScheme() {
        throw new NotImplementedException();
    }

    @Override
    public AColGroup recompress() {
        throw new NotImplementedException();
    }

    @Override
    public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
        final IEncode enc = getEncoding();
        final EstimationFactors ef = new EstimationFactors(getNumValues(), _data.size(), _data.size(),
            _dict.getSparsity());
        return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), getCompType(), enc);
    }

    @Override
    public IEncode getEncoding() {
        return EncodingFactory.create(_data);
    }

    /** True only if {@code that} is a DDCFOR group sharing the very same mapping instance. */
    @Override
    public boolean sameIndexStructure(AColGroupCompressed that) {
        return that instanceof ColGroupDDCFOR && ((ColGroupDDCFOR) that)._data == _data;
    }

    @Override
    protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
        throw new NotImplementedException();
    }

    @Override
    public String toString() {
        return super.toString() + String.format("\n%15s ", "Data: ") + _data + String.format("\n%15s", "Reference:")
            + Arrays.toString(_reference);
    }
}
