package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.class */
public class ColGroupSDCSingle extends AMorphingMMColGroup {
    private static final long serialVersionUID = 3883228464052204200L;

    /** Offsets of the rows that take the single dictionary tuple; every other row takes {@link #_defaultTuple}. */
    protected AOffset _indexes;

    /** Per-column value used for every row that is not listed in {@link #_indexes}. */
    protected double[] _defaultTuple;

    /**
     * Creates an empty instance, used before {@link #readFields(DataInput)}.
     *
     * NOTE(review): the decompiler reports this was originally protected.
     *
     * @param numRows number of rows in the column group
     */
    public ColGroupSDCSingle(int numRows) {
        super(numRows);
    }

    private ColGroupSDCSingle(int[] colIndexes, int numRows, ADictionary dict, double[] defaultTuple, AOffset offsets,
        int[] cachedCounts) {
        super(colIndexes, numRows, dict, cachedCounts);
        this._indexes = offsets;
        this._zeros = false;
        this._defaultTuple = defaultTuple;
        if (this._indexes == null) {
            throw new NullPointerException("null indexes is invalid for SDCSingle");
        }
    }

    /**
     * Factory that picks the cheapest representation for the given content.
     *
     * <ul>
     * <li>no dictionary and an all-zero default: {@link ColGroupEmpty}</li>
     * <li>no dictionary: an SDCSingle whose offset rows hold an all-zero tuple</li>
     * <li>all-zero default: {@link ColGroupSDCSingleZeros}</li>
     * <li>otherwise: a plain {@code ColGroupSDCSingle}</li>
     * </ul>
     *
     * NOTE(review): the decompiler reports this was originally protected.
     *
     * @param colIndexes   column indexes covered by the group
     * @param numRows      number of rows
     * @param dict         dictionary holding the single tuple, may be null
     * @param defaultTuple value per column for rows not in {@code offsets}
     * @param offsets      rows that take the dictionary tuple
     * @param cachedCounts cached counts passed through to the group, may be null
     * @return the constructed column group
     */
    public static AColGroup create(int[] colIndexes, int numRows, ADictionary dict, double[] defaultTuple,
        AOffset offsets, int[] cachedCounts) {
        final boolean allZeroDefault = FORUtil.allZero(defaultTuple);
        if (dict == null && allZeroDefault) {
            return new ColGroupEmpty(colIndexes);
        }
        if (dict == null) {
            return new ColGroupSDCSingle(colIndexes, numRows, new Dictionary(new double[colIndexes.length]),
                defaultTuple, offsets, cachedCounts);
        }
        if (allZeroDefault) {
            return ColGroupSDCSingleZeros.create(colIndexes, numRows, dict, offsets, cachedCounts);
        }
        return new ColGroupSDCSingle(colIndexes, numRows, dict, defaultTuple, offsets, cachedCounts);
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.SDC;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.SDCSingle;
    }

    /**
     * Returns the value of one cell.
     *
     * @param r   row index
     * @param colIdx column position within this group
     * @return the dictionary value if {@code r} is an offset row, otherwise the default value
     */
    @Override
    public double getIdx(int r, int colIdx) {
        final AIterator it = this._indexes.getIterator(r);
        if (it == null || it.value() != r) {
            return this._defaultTuple[colIdx];
        }
        return this._dict.getValue(colIdx);
    }

    @Override
    public ADictionary getDictionary() {
        throw new NotImplementedException("Not implemented getting the dictionary out, and i think we should consider removing the option");
    }

    @Override
    protected double[] preAggSumRows() {
        return this._dict.sumAllRowsToDoubleWithDefault(this._defaultTuple);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return this._dict.sumAllRowsToDoubleSqWithDefault(this._defaultTuple);
    }

    @Override
    protected double[] preAggProductRows() {
        throw new NotImplementedException("Should implement preAgg with extra cell");
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return this._dict.aggregateRowsWithDefault(builtin, this._defaultTuple);
    }

    /** Folds the dictionary values and then every default value into {@code init} using {@code builtin}. */
    @Override
    protected double computeMxx(double init, Builtin builtin) {
        double agg = this._dict.aggregate(init, builtin);
        for (double v : this._defaultTuple) {
            agg = builtin.execute(agg, v);
        }
        return agg;
    }

    /** Column-wise min/max: aggregates the dictionary per column, then folds in each column's default value. */
    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        this._dict.aggregateCols(c, builtin, this._colIndexes);
        for (int j = 0; j < this._colIndexes.length; j++) {
            final int col = this._colIndexes[j];
            c[col] = builtin.execute(c[col], this._defaultTuple[j]);
        }
    }

    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        computeRowSums(c, rl, ru, preAgg, this._indexes, this._numRows);
    }

    /**
     * Adds the per-row aggregate to {@code c[r]} for every row {@code r} in {@code [rl, ru)}.
     *
     * An offset row receives {@code preAgg[0]}; every other row receives {@code preAgg[1]}.
     *
     * @param c       output array indexed by absolute row
     * @param rl      first row (inclusive)
     * @param ru      last row (exclusive)
     * @param preAgg  {@code [tupleAggregate, defaultAggregate]}
     * @param indexes offsets of rows taking the tuple
     * @param nRows   number of rows (unused here, kept for signature compatibility)
     */
    protected static final void computeRowSums(double[] c, int rl, int ru, double[] preAgg, AOffset indexes,
        int nRows) {
        final double def = preAgg[1];
        final double norm = preAgg[0];
        int r = rl;
        final AIterator it = indexes.getIterator(rl);
        if (it != null && it.value() > ru) {
            // No offset falls inside [rl, ru).
            indexes.cacheIterator(it, ru);
        } else if (it != null && ru > indexes.getOffsetToLast()) {
            // The last offset lies strictly inside the range (ru is exclusive): consume all remaining offsets.
            final int lastOffset = indexes.getOffsetToLast();
            while (true) {
                if (it.value() == r) {
                    c[r] += norm;
                    if (it.value() >= lastOffset) {
                        break;
                    }
                    it.next();
                } else {
                    c[r] += def;
                }
                r++;
            }
            r++; // step past the last offset row
        } else if (it != null) {
            // The range ends at or before the last offset.
            while (r < ru) {
                if (it.value() == r) {
                    c[r] += norm;
                    it.next(); // advance so later offsets in the range are recognised
                } else {
                    c[r] += def;
                }
                r++;
            }
            indexes.cacheIterator(it, ru);
        }
        // All remaining rows in the range carry the default tuple.
        while (r < ru) {
            c[r] += def;
            r++;
        }
    }

    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        computeRowMxx(c, builtin, rl, ru, this._indexes, this._numRows, preAgg[1], preAgg[0]);
    }

    /**
     * Folds the per-row aggregate into {@code c[r]} with {@code builtin} for every row {@code r} in {@code [rl, ru)}.
     *
     * NOTE(review): the decompiler reports this was originally protected.
     *
     * @param c       output array indexed by absolute row
     * @param builtin aggregation function, e.g. min or max
     * @param rl      first row (inclusive)
     * @param ru      last row (exclusive)
     * @param indexes offsets of rows taking the tuple
     * @param nRows   number of rows (unused here, kept for signature compatibility)
     * @param def     aggregate applied to non-offset rows
     * @param norm    aggregate applied to offset rows
     */
    public static final void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, AOffset indexes, int nRows,
        double def, double norm) {
        int r = rl;
        final AIterator it = indexes.getIterator(rl);
        if (it != null && it.value() > ru) {
            indexes.cacheIterator(it, ru);
        } else if (it != null && ru > indexes.getOffsetToLast()) {
            // The last offset lies strictly inside the range (ru is exclusive).
            final int lastOffset = indexes.getOffsetToLast();
            while (true) {
                if (it.value() == r) {
                    c[r] = builtin.execute(c[r], norm);
                    if (it.value() >= lastOffset) {
                        break;
                    }
                    it.next();
                } else {
                    c[r] = builtin.execute(c[r], def);
                }
                r++;
            }
            r++; // step past the last offset row
        } else if (it != null) {
            while (r < ru) {
                if (it.value() == r) {
                    c[r] = builtin.execute(c[r], norm);
                    it.next();
                } else {
                    c[r] = builtin.execute(c[r], def);
                }
                r++;
            }
            indexes.cacheIterator(it, ru);
        }
        while (r < ru) {
            c[r] = builtin.execute(c[r], def);
            r++;
        }
    }

    /** Adds the default tuple's contribution, weighted by the number of non-offset rows, to the total sum. */
    @Override
    public void computeSum(double[] c, int nRows) {
        super.computeSum(c, nRows);
        final int defaultRows = this._numRows - getCounts()[0];
        for (double v : this._defaultTuple) {
            c[0] += v * defaultRows;
        }
    }

    /** Adds each default value, weighted by the number of non-offset rows, to its column sum. */
    @Override
    public void computeColSums(double[] c, int nRows) {
        super.computeColSums(c, nRows);
        final int defaultRows = this._numRows - getCounts()[0];
        for (int j = 0; j < this._colIndexes.length; j++) {
            c[this._colIndexes[j]] += this._defaultTuple[j] * defaultRows;
        }
    }

    /** Adds the squared default values, weighted by the number of non-offset rows, to the total. */
    @Override
    public void computeSumSq(double[] c, int nRows) {
        super.computeSumSq(c, nRows);
        final int defaultRows = this._numRows - getCounts()[0];
        for (int j = 0; j < this._colIndexes.length; j++) {
            c[0] += this._defaultTuple[j] * this._defaultTuple[j] * defaultRows;
        }
    }

    /** Adds each squared default value, weighted by the number of non-offset rows, to its column. */
    @Override
    public void computeColSumsSq(double[] c, int nRows) {
        super.computeColSumsSq(c, nRows);
        final int defaultRows = this._numRows - getCounts()[0];
        for (int j = 0; j < this._colIndexes.length; j++) {
            c[this._colIndexes[j]] += this._defaultTuple[j] * this._defaultTuple[j] * defaultRows;
        }
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        this._dict.productWithDefault(c, getCounts(), this._defaultTuple, this._numRows - getCounts()[0]);
    }

    /** Multiplies each column product by its default value raised to the number of non-offset rows. */
    @Override
    public void computeColProduct(double[] c, int nRows) {
        super.computeColProduct(c, nRows);
        final int defaultRows = this._numRows - getCounts()[0];
        for (int j = 0; j < this._colIndexes.length; j++) {
            c[this._colIndexes[j]] *= Math.pow(this._defaultTuple[j], defaultRows);
        }
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        throw new NotImplementedException();
    }

    /** Writes the number of offset rows into {@code counts[0]} and returns the same array. */
    @Override
    public int[] getCounts(int[] counts) {
        counts[0] = this._indexes.getSize();
        return counts;
    }

    /** Parent estimate plus the offsets plus 8 bytes per default value. */
    @Override
    public long estimateInMemorySize() {
        return super.estimateInMemorySize() + this._indexes.getInMemorySize() + (8 * this._colIndexes.length);
    }

    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        final double[] newDefault = new double[this._defaultTuple.length];
        for (int j = 0; j < newDefault.length; j++) {
            newDefault[j] = op.executeScalar(this._defaultTuple[j]);
        }
        return create(this._colIndexes, this._numRows, this._dict.applyScalarOp(op), newDefault, this._indexes,
            getCachedCounts());
    }

    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        final double[] newDefault = new double[this._defaultTuple.length];
        for (int j = 0; j < newDefault.length; j++) {
            newDefault[j] = op.fn.execute(this._defaultTuple[j]);
        }
        return create(this._colIndexes, this._numRows, this._dict.applyUnaryOp(op), newDefault, this._indexes,
            getCachedCounts());
    }

    /** Applies {@code op(v[col], value)} to the dictionary and the default tuple. */
    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newDefault = new double[this._defaultTuple.length];
        for (int j = 0; j < newDefault.length; j++) {
            newDefault[j] = op.fn.execute(v[this._colIndexes[j]], this._defaultTuple[j]);
        }
        return create(this._colIndexes, this._numRows, this._dict.binOpLeft(op, v, this._colIndexes), newDefault,
            this._indexes, getCachedCounts());
    }

    /**
     * Applies {@code op(value, v[col])} to the dictionary and the default tuple.
     *
     * Routed through {@link #create} like {@link #binaryRowOpLeft}, so an all-zero result default is demoted to
     * {@link ColGroupSDCSingleZeros}.
     */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newDefault = new double[this._defaultTuple.length];
        for (int j = 0; j < newDefault.length; j++) {
            newDefault[j] = op.fn.execute(this._defaultTuple[j], v[this._colIndexes[j]]);
        }
        return create(this._colIndexes, this._numRows, this._dict.binOpRight(op, v, this._colIndexes), newDefault,
            this._indexes, getCachedCounts());
    }

    /** Serializes the parent state, then the offsets, then one double per default value. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        this._indexes.write(out);
        for (double v : this._defaultTuple) {
            out.writeDouble(v);
        }
    }

    /** Reads back what {@link #write(DataOutput)} produced; the default tuple length equals the column count. */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        this._indexes = OffsetFactory.readIn(in);
        this._defaultTuple = new double[this._colIndexes.length];
        for (int j = 0; j < this._colIndexes.length; j++) {
            this._defaultTuple[j] = in.readDouble();
        }
    }

    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + this._indexes.getExactSizeOnDisk() + (8 * this._colIndexes.length);
    }

    /** Replaces {@code pattern} by {@code replacement} in both the dictionary and the default tuple. */
    @Override
    public AColGroup replace(double pattern, double replacement) {
        final ADictionary replaced = this._dict.replace(pattern, replacement, this._colIndexes.length);
        final double[] newDefault = new double[this._defaultTuple.length];
        for (int j = 0; j < newDefault.length; j++) {
            newDefault[j] = this._defaultTuple[j] == pattern ? replacement : this._defaultTuple[j];
        }
        return create(this._colIndexes, this._numRows, replaced, newDefault, this._indexes, getCachedCounts());
    }

    /**
     * Moves the default tuple into {@code constV}.
     *
     * @return an {@link ColGroupSDCSingleZeros} holding the dictionary minus the default tuple
     */
    @Override
    public AColGroup extractCommon(double[] constV) {
        for (int j = 0; j < this._colIndexes.length; j++) {
            constV[this._colIndexes[j]] += this._defaultTuple[j];
        }
        return ColGroupSDCSingleZeros.create(this._colIndexes, this._numRows,
            this._dict.subtractTuple(this._defaultTuple), this._indexes, getCachedCounts());
    }

    /** Parent count plus, for each non-zero default value, the number of non-offset rows. */
    @Override
    public long getNumberNonZeros(int nRows) {
        long nnz = super.getNumberNonZeros(nRows);
        final int defaultRows = this._numRows - getCounts()[0];
        for (int j = 0; j < this._colIndexes.length; j++) {
            if (this._defaultTuple[j] != 0.0) {
                nnz += defaultRows;
            }
        }
        return nnz;
    }

    /** Adds the default value, weighted by the number of non-offset rows, to the moment (uses column 0 only). */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        final CM_COV_Object ret = super.centralMoment(op, nRows);
        op.fn.execute(ret, this._defaultTuple[0], this._numRows - getCounts()[0]);
        return ret;
    }

    /**
     * Expands the single column into {@code max} indicator columns.
     *
     * The default value {@code d} maps to column {@code (int) d - 1} when {@code 0 < d <= max}.
     *
     * NOTE(review): a fractional default in (0, 1) yields index -1 below, as in the original; confirm whether
     * {@code cast} should apply here.
     */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final ADictionary expanded = this._dict.rexpandCols(max, ignore, cast, this._colIndexes.length);
        final double def = this._defaultTuple[0];
        if (expanded == null) {
            if (def <= 0.0 || def > max) {
                return ColGroupEmpty.create(max);
            }
            final double[] newDefault = new double[max];
            newDefault[((int) def) - 1] = 1.0;
            return new ColGroupSDCSingle(Util.genColsIndices(max), nRows, new Dictionary(new double[max]), newDefault,
                this._indexes, null);
        }
        if (def <= 0.0) {
            if (ignore) {
                return ColGroupSDCSingleZeros.create(Util.genColsIndices(max), nRows, expanded, this._indexes,
                    getCachedCounts());
            }
            throw new DMLRuntimeException("Invalid content of zero in rexpand");
        }
        if (def > max) {
            return ColGroupSDCSingleZeros.create(Util.genColsIndices(max), nRows, expanded, this._indexes,
                getCachedCounts());
        }
        final double[] newDefault = new double[max];
        newDefault[((int) def) - 1] = 1.0;
        return new ColGroupSDCSingle(Util.genColsIndices(max), nRows, expanded, newDefault, this._indexes,
            getCachedCounts());
    }

    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        final int nVals = getNumValues();
        return e.getCost(nRows, getCounts()[0], getNumCols(), nVals, this._dict.getSparsity());
    }

    /** Slices the parent state and copies the default values for columns {@code [idStart, idEnd)}. */
    @Override
    public AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        final ColGroupSDCSingle ret = (ColGroupSDCSingle) super.sliceMultiColumns(idStart, idEnd, outputCols);
        ret._defaultTuple = Arrays.copyOfRange(this._defaultTuple, idStart, idEnd);
        return ret;
    }

    /** Slices the parent state and keeps only the default value of column position {@code idx}. */
    @Override
    public AColGroup sliceSingleColumn(int idx) {
        final ColGroupSDCSingle ret = (ColGroupSDCSingle) super.sliceSingleColumn(idx);
        ret._defaultTuple = new double[] {this._defaultTuple[idx]};
        return ret;
    }

    /** Checks the zero flag, then the dictionary, then the default tuple for {@code pattern}. */
    @Override
    public boolean containsValue(double pattern) {
        if (pattern == 0.0 && this._zeros) {
            return true;
        }
        if (this._dict.containsValue(pattern)) {
            return true;
        }
        for (double v : this._defaultTuple) {
            if (v == pattern) {
                return true;
            }
        }
        return false;
    }

    @Override
    public String toString() {
        return super.toString() + String.format("\n%15s", "Default: ") + Arrays.toString(this._defaultTuple)
            + String.format("\n%15s", "Indexes: ") + this._indexes.toString();
    }
}
