package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.utils.ABitmap;
import org.apache.sysds.runtime.compress.utils.Bitmap;
import org.apache.sysds.runtime.compress.utils.BitmapLossy;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/ColGroupValue.class */
/**
 * Base class for column groups that store their distinct value tuples in a dictionary
 * ({@link ADictionary}).
 *
 * <p>Also manages an optional per-thread scratch-buffer pool (see
 * {@link #setupThreadLocalMemory(int)}) that {@link #allocDVector(int, boolean)} and
 * {@link #allocIVector(int, boolean)} reuse instead of allocating new arrays.
 *
 * <p>Note: several members are {@code public} here although the decompiler reported them as
 * originally {@code protected}; the wider visibility is kept for compatibility.
 */
public abstract class ColGroupValue extends ColGroup {
    private static final long serialVersionUID = 3786247536054353658L;

    /**
     * Per-thread scratch buffers: key = int buffer, value = double buffer. A thread that never
     * called {@link #setupThreadLocalMemory(int)} (or called {@link #cleanupThreadLocalMemory()})
     * gets a fresh {@code new Pair<>()}, whose key/value the allocators test against null.
     */
    private static final ThreadLocal<Pair<int[], double[]>> memPool = new ThreadLocal<Pair<int[], double[]>>() {
        @Override
        protected Pair<int[], double[]> initialValue() {
            return new Pair<>();
        }
    };

    /** Dictionary holding the distinct value tuples of this column group. */
    protected ADictionary _dict;

    /** Empty constructor; fields are expected to be filled later, e.g. via {@link #readFields(DataInput)}. */
    public ColGroupValue() {
    }

    /**
     * Builds the group's dictionary from a bitmap.
     *
     * @param colIndices column indices covered by this group
     * @param numRows number of rows
     * @param ubm bitmap of distinct values; a {@code Full} bitmap yields a {@link Dictionary},
     *     a {@code Lossy} bitmap yields a {@link QDictionary} and marks the group lossy
     * @param cs compression settings (controls optional sorting of values by frequency)
     */
    public ColGroupValue(int[] colIndices, int numRows, ABitmap ubm, CompressionSettings cs) {
        super(colIndices, numRows);
        _lossy = false;
        _zeros = ubm.containsZero();
        // NOTE(review): 65536 looks like an inlined compile-time constant (a block size) — confirm
        // against CompressionSettings. Sorting is only applied to groups larger than that.
        if (cs.sortValuesByLength && numRows > 65536) {
            ubm.sortValuesByFrequency();
        }
        switch (ubm.getType()) {
            case Full:
                _dict = new Dictionary(((Bitmap) ubm).getValues());
                break;
            case Lossy:
                _dict = new QDictionary((BitmapLossy) ubm);
                _lossy = true;
                break;
            default:
                // Other bitmap types leave _dict unassigned.
                break;
        }
    }

    /**
     * Creates a group over an existing dictionary.
     * NOTE(review): {@code _lossy} is not set here and keeps whatever ColGroup assigns — confirm.
     */
    public ColGroupValue(int[] colIndices, int numRows, ADictionary dict) {
        super(colIndices, numRows);
        _dict = dict;
    }

    /** @return number of distinct tuples in the dictionary for this group's column count */
    public int getNumValues() {
        return _dict.getNumberOfValues(_colIndexes.length);
    }

    @Override
    public double[] getValues() {
        return _dict.getValues();
    }

    /**
     * @return the quantized byte values of the dictionary
     * @throws ClassCastException if the dictionary is not a {@link QDictionary}
     */
    public byte[] getByteValues() {
        return ((QDictionary) _dict).getValuesByte();
    }

    /**
     * Returns the dictionary values as a single-column block. If the group contains zeros, one
     * extra trailing row is allocated and left unset (zero).
     */
    @Override
    public MatrixBlock getValuesAsBlock() {
        final double[] values = getValues();
        final int vlen = values.length;
        final MatrixBlock ret = new MatrixBlock(_zeros ? vlen + 1 : vlen, 1, false);
        for (int r = 0; r < vlen; r++) {
            ret.quickSetValue(r, 0, values[r]);
        }
        return ret;
    }

    /** @return per-tuple counts over all rows, written into a (possibly pooled) int buffer */
    public final int[] getCounts() {
        return getCounts(allocIVector(countsLength(), true));
    }

    /**
     * @param rl presumably the first row (inclusive) of the range — confirm with implementations
     * @param ru presumably the last row (exclusive) of the range — confirm with implementations
     * @return per-tuple counts for the row range, written into a (possibly pooled) int buffer
     */
    public final int[] getCounts(int rl, int ru) {
        return getCounts(rl, ru, allocIVector(countsLength(), true));
    }

    /** Number of count slots: one per distinct tuple, plus one for the zero tuple if present. */
    private int countsLength() {
        final int numVals = getNumValues();
        return _zeros ? numVals + 1 : numVals;
    }

    @Override
    public boolean getIfCountsType() {
        return true;
    }

    /** Delegates to {@link ADictionary#hasZeroTuple(int)} for this group's column count. */
    public int containsAllZeroValue() {
        return _dict.hasZeroTuple(_colIndexes.length);
    }

    /**
     * Dot product of dictionary tuple {@code valIx} with vector {@code b}, gathered at this
     * group's column indices.
     *
     * @param valIx tuple index into the dictionary
     * @param b dense vector indexed by global column index
     * @param dictVals row-major dictionary values ({@code numCols} entries per tuple)
     */
    public final double sumValues(int valIx, double[] b, double[] dictVals) {
        final int numCols = getNumCols();
        final int off = valIx * numCols;
        double sum = 0.0;
        for (int j = 0; j < numCols; j++) {
            sum += dictVals[off + j] * b[_colIndexes[j]];
        }
        return sum;
    }

    /** Same as {@link #preaggValues(int, double[], boolean, double[])} with a pooled buffer. */
    public final double[] preaggValues(int numVals, double[] b, double[] dictVals) {
        return preaggValues(numVals, b, false, dictVals);
    }

    /**
     * Pre-aggregates each dictionary tuple against vector {@code b}.
     *
     * @param numVals number of tuples to process
     * @param b dense vector indexed by global column index
     * @param allocNew if true a new array is allocated, otherwise a pooled buffer may be reused
     * @param dictVals row-major dictionary values
     * @return array of at least {@code numVals + 1} entries; entries {@code [0, numVals)} hold the
     *     per-tuple results and entry {@code numVals} is 0
     */
    protected final double[] preaggValues(int numVals, double[] b, boolean allocNew, double[] dictVals) {
        final double[] ret = allocNew ? new double[numVals + 1] : allocDVector(numVals + 1, false);
        if (_colIndexes.length == 1) {
            final double bv = b[_colIndexes[0]];
            for (int k = 0; k < numVals; k++) {
                ret[k] = dictVals[k] * bv;
            }
        } else {
            for (int k = 0; k < numVals; k++) {
                ret[k] = sumValues(k, b, dictVals);
            }
        }
        // The pooled buffer is not reset, so clear the extra trailing slot to match the
        // zero-initialized fresh-array path. NOTE(review): presumably reserved for the zero tuple.
        ret[numVals] = 0.0;
        return ret;
    }

    /** Folds the group's min/max (per {@code builtin}) into {@code c[0]}, including 0 if the group has zeros. */
    protected void computeMxx(double[] c, Builtin builtin) {
        if (_zeros) {
            c[0] = builtin.execute(c[0], 0.0);
        }
        c[0] = _dict.aggregate(c[0], builtin);
    }

    /** Folds per-column min/max into {@code c} at this group's column indices, including 0 if the group has zeros. */
    protected void computeColMxx(double[] c, Builtin builtin) {
        if (_zeros) {
            for (int j = 0; j < _colIndexes.length; j++) {
                final int col = _colIndexes[j];
                c[col] = builtin.execute(c[col], 0.0);
            }
        }
        _dict.aggregateCols(c, builtin, _colIndexes);
    }

    /** Applies a scalar operator to a copy of the dictionary; this group's dictionary is untouched. */
    public ADictionary applyScalarOp(ScalarOperator op) {
        return _dict.mo425clone().apply(op);
    }

    /** Delegates to {@link ADictionary#applyScalarOp(ScalarOperator, double, int)}. */
    public ADictionary applyScalarOp(ScalarOperator op, double newVal, int numCols) {
        return _dict.applyScalarOp(op, newVal, numCols);
    }

    @Override
    public void unaryAggregateOperations(AggregateUnaryOperator op, double[] c) {
        unaryAggregateOperations(op, c, 0, _numRows);
    }

    /**
     * Executes a sum / sum-of-squares / mean / min / max aggregate over this group.
     *
     * @throws DMLScriptException if the operator is neither a supported Kahan sum variant nor a
     *     MIN/MAX builtin
     */
    @Override
    public void unaryAggregateOperations(AggregateUnaryOperator op, double[] c, int rl, int ru) {
        final Object fn = op.aggOp.increOp.fn;
        final boolean isMean = fn instanceof Mean;

        if (fn instanceof KahanPlus || fn instanceof KahanPlusSq || isMean) {
            // Mean uses plain Kahan sums; only KahanPlusSq switches to sums of squares.
            final KahanFunction kplus = (fn instanceof KahanPlus || isMean)
                ? KahanPlus.getKahanPlusFnObject()
                : KahanPlusSq.getKahanPlusSqFnObject();
            if (op.indexFn instanceof ReduceAll) {
                computeSum(c, kplus);
            } else if (op.indexFn instanceof ReduceCol) {
                computeRowSums(c, kplus, rl, ru, isMean);
            } else if (op.indexFn instanceof ReduceRow) {
                computeColSums(c, kplus);
            }
            return;
        }

        if (!(fn instanceof Builtin)) {
            throw new DMLScriptException("Unknown UnaryAggregate operator on CompressedMatrixBlock");
        }
        final Builtin builtin = (Builtin) fn;
        final Builtin.BuiltinCode code = builtin.getBuiltinCode();
        if (code != Builtin.BuiltinCode.MAX && code != Builtin.BuiltinCode.MIN) {
            throw new DMLScriptException("Unknown UnaryAggregate operator on CompressedMatrixBlock");
        }

        if (op.indexFn instanceof ReduceAll) {
            computeMxx(c, builtin);
        } else if (op.indexFn instanceof ReduceCol) {
            computeRowMxx(c, builtin, rl, ru);
        } else if (op.indexFn instanceof ReduceRow) {
            computeColMxx(c, builtin);
        }
    }

    /**
     * Kahan-adds {@code value} into the (sum, correction) pair stored at {@code c[ix]} and
     * {@code c[ix + 1]}, using {@code kbuff} as scratch.
     */
    public void setandExecute(double[] c, KahanObject kbuff, KahanPlus kplus, double value, int ix) {
        kbuff.set(c[ix], c[ix + 1]);
        kplus.execute2(kbuff, value);
        c[ix] = kbuff._sum;
        c[ix + 1] = kbuff._correction;
    }

    /** Installs int and double scratch buffers of length {@code len} for the current thread. */
    public static void setupThreadLocalMemory(int len) {
        final Pair<int[], double[]> p = new Pair<>();
        p.setKey(new int[len]);
        p.setValue(new double[len]);
        memPool.set(p);
    }

    /** Drops the current thread's scratch buffers. */
    public static void cleanupThreadLocalMemory() {
        memPool.remove();
    }

    /**
     * Returns a double buffer with at least {@code len} entries. The thread's pooled buffer is
     * reused when present and large enough; otherwise a fresh zero-filled array is allocated.
     *
     * @param len minimum required length
     * @param reset if true, entries {@code [0, len)} of a reused buffer are set to 0
     */
    public static double[] allocDVector(int len, boolean reset) {
        final double[] pooled = memPool.get().getValue();
        // A missing or undersized pool buffer would otherwise fail in Arrays.fill or hand the
        // caller an array shorter than requested.
        if (pooled == null || pooled.length < len) {
            return new double[len];
        }
        if (reset) {
            Arrays.fill(pooled, 0, len, 0.0);
        }
        return pooled;
    }

    /**
     * Returns an int buffer with at least {@code len} entries. The thread's pooled buffer is
     * reused when present and large enough; otherwise a fresh zero-filled array is allocated.
     *
     * @param len minimum required length
     * @param reset if true, entries {@code [0, len)} of a reused buffer are set to 0
     */
    public static int[] allocIVector(int len, boolean reset) {
        final int[] pooled = memPool.get().getKey();
        if (pooled == null || pooled.length < len) {
            return new int[len];
        }
        if (reset) {
            Arrays.fill(pooled, 0, len, 0);
        }
        return pooled;
    }

    @Override
    public String toString() {
        final double[] values = _dict.getValues();
        return super.toString()
            + String.format("\n%15s%5d ", "Columns:", Integer.valueOf(_colIndexes.length))
            + Arrays.toString(_colIndexes)
            + String.format("\n%15s%5d ", "Values:", Integer.valueOf(values.length))
            + Arrays.toString(values);
    }

    @Override
    public boolean isLossy() {
        return _lossy;
    }

    /** Reads the layout produced by {@link #write(DataOutput)}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        _numRows = in.readInt();
        final int numCols = in.readInt();
        _zeros = in.readBoolean();
        _lossy = in.readBoolean();
        _colIndexes = new int[numCols];
        for (int j = 0; j < numCols; j++) {
            _colIndexes[j] = in.readInt();
        }
        _dict = ADictionary.read(in, _lossy);
    }

    /**
     * Writes: numRows (int), numCols (int), zeros flag (boolean), lossy flag (boolean), the column
     * indices (one int each), then the dictionary.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(_numRows);
        out.writeInt(getNumCols());
        out.writeBoolean(_zeros);
        out.writeBoolean(_lossy);
        for (int j = 0; j < _colIndexes.length; j++) {
            out.writeInt(_colIndexes[j]);
        }
        _dict.write(out);
    }

    /** Size in bytes of the {@link #write(DataOutput)} layout. */
    @Override
    public long getExactSizeOnDisk() {
        long size = 4 + 4;               // numRows, numCols
        size += 1 + 1;                   // zeros, lossy flags
        size += 4L * _colIndexes.length; // column indices
        return size + _dict.getExactSizeOnDisk();
    }

    /** Computes per-tuple counts into {@code counts} and returns it (implementation-defined). */
    public abstract int[] getCounts(int[] counts);

    /** Computes per-tuple counts for a row range into {@code counts} and returns it (implementation-defined). */
    public abstract int[] getCounts(int rl, int ru, int[] counts);

    /** Full (ReduceAll) sum or sum of squares into {@code c}. */
    protected abstract void computeSum(double[] c, KahanFunction kplus);

    /** Row-wise (ReduceCol) sums for rows {@code [rl, ru)}; {@code mean} is true for mean aggregates. */
    protected abstract void computeRowSums(double[] c, KahanFunction kplus, int rl, int ru, boolean mean);

    /** Column-wise (ReduceRow) sums into {@code c}. */
    protected abstract void computeColSums(double[] c, KahanFunction kplus);

    /** Row-wise (ReduceCol) min/max for rows {@code [rl, ru)}. */
    protected abstract void computeRowMxx(double[] c, Builtin builtin, int rl, int ru);
}
