package org.apache.sysds.runtime.codegen;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.SpoofCUDAOperator;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFactory;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/runtime/codegen/SpoofRowwise.class */
public abstract class SpoofRowwise extends SpoofOperator {
    private static final long serialVersionUID = 6242910797139642998L;
    /** Row type of this operator; selects the output shape and the serial/parallel aggregation path. */
    protected final RowType _type;
    /** Constant second dimension; used as the output column count for the *_CONST row types and, when non-negative, for B1 row types. */
    protected final long _constDim2;
    /** Flag forwarded as the last argument of prepInputMatrices; presumably "transpose first side input" - TODO confirm. */
    protected final boolean _tB1;
    /** Value handed to LibSpoofPrimitives.setupThreadLocalMemory; thread-local memory is only set up when it is positive. */
    protected final int _reqVectMem;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/apache/sysds/runtime/codegen/SpoofRowwise$OutputDimensions.class */
    /**
     * Output shape of the enclosing row-wise operator, derived from its row type.
     */
    public class OutputDimensions {
        /** Number of output rows. */
        public final int rows;
        /** Number of output columns. */
        public final int cols;

        /**
         * Computes the output dimensions for the enclosing operator's row type.
         *
         * @param i  number of rows of the main input
         * @param i2 number of columns of the main input
         * @param i3 second dimension used by the B1 row types (callers in this class pass the
         *           constant dimension or the minimum column count of the matrix side inputs)
         */
        public OutputDimensions(int i, int i2, int i3) {
            final int outRows;
            final int outCols;
            // The decompiler emitted an undefined register here; the switch is on the
            // enclosing operator's row type.
            switch (SpoofRowwise.this._type) {
                case NO_AGG:
                    outRows = i;
                    outCols = i2;
                    break;
                case NO_AGG_B1:
                    outRows = i;
                    outCols = i3;
                    break;
                case NO_AGG_CONST:
                    outRows = i;
                    outCols = (int) SpoofRowwise.this._constDim2;
                    break;
                case FULL_AGG:
                    outRows = 1;
                    outCols = 1;
                    break;
                case ROW_AGG:
                    outRows = i;
                    outCols = 1;
                    break;
                case COL_AGG:
                    outRows = 1;
                    outCols = i2;
                    break;
                case COL_AGG_T:
                    outRows = i2;
                    outCols = 1;
                    break;
                case COL_AGG_B1:
                    outRows = i3;
                    outCols = i2;
                    break;
                case COL_AGG_B1_T:
                    outRows = i2;
                    outCols = i3;
                    break;
                case COL_AGG_B1R:
                    outRows = 1;
                    outCols = i3;
                    break;
                case COL_AGG_CONST:
                    outRows = 1;
                    outCols = (int) SpoofRowwise.this._constDim2;
                    break;
                default:
                    // Unknown row type: empty output.
                    outRows = 0;
                    outCols = 0;
                    break;
            }
            this.rows = outRows;
            this.cols = outCols;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/codegen/SpoofRowwise$ParColAggTask.class */
    /**
     * Parallel task that runs the operator over the row range [_rl, _ru) of the main input
     * and returns a freshly allocated 1 x _outLen dense block holding its partial result.
     */
    public class ParColAggTask implements Callable<DenseBlock> {
        private final MatrixBlock _a;
        private final SpoofOperator.SideInput[] _b;
        private final double[] _scalars;
        private final int _clen;
        private final int _clen2;
        private final int _outLen;
        private final int _rl;
        private final int _ru;

        /**
         * @param matrixBlock  main input
         * @param sideInputArr prepared side inputs
         * @param dArr         prepared scalar inputs
         * @param i            number of columns of the main input
         * @param i2           second dimension passed to the thread-local memory setup
         * @param i3           length of the partial output block
         * @param i4           first row (inclusive) processed by this task
         * @param i5           last row (exclusive) processed by this task
         */
        protected ParColAggTask(MatrixBlock matrixBlock, SpoofOperator.SideInput[] sideInputArr, double[] dArr, int i, int i2, int i3, int i4, int i5) {
            this._a = matrixBlock;
            this._b = sideInputArr;
            this._scalars = dArr;
            this._clen = i;
            this._clen2 = i2;
            this._outLen = i3;
            this._rl = i4;
            this._ru = i5;
        }

        /**
         * Executes the operator over this task's row range.
         *
         * @return the partial output block of this task
         */
        @Override // java.util.concurrent.Callable
        public DenseBlock call() {
            final boolean useVectMem = SpoofRowwise.this._reqVectMem > 0;
            if (useVectMem) {
                LibSpoofPrimitives.setupThreadLocalMemory(SpoofRowwise.this._reqVectMem, this._clen, this._clen2);
            }
            try {
                DenseBlock partial = DenseBlockFactory.createDenseBlock(1, this._outLen);
                if (this._a.isInSparseFormat()) {
                    SpoofRowwise.this.executeSparse(this._a.getSparseBlock(), this._b, this._scalars, partial, this._clen, this._rl, this._ru, 0L);
                } else {
                    SpoofRowwise.this.executeDense(this._a.getDenseBlock(), this._b, this._scalars, partial, this._clen, this._rl, this._ru, 0L);
                }
                return partial;
            } finally {
                // Always release the thread-local buffers, even if the operator throws,
                // so that a reused worker thread does not keep them around.
                if (useVectMem) {
                    LibSpoofPrimitives.cleanupThreadLocalMemory();
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/codegen/SpoofRowwise$ParExecTask.class */
    /**
     * Parallel task that runs the operator over the row range [_rl, _ru) of the main input,
     * writing directly into the shared output block, and returns the number of non-zeros
     * of the output rows it covered.
     */
    public class ParExecTask implements Callable<Long> {
        private final MatrixBlock _a;
        private final SpoofOperator.SideInput[] _b;
        private final MatrixBlock _c;
        private final double[] _scalars;
        private final int _clen;
        private final int _clen2;
        private final int _rl;
        private final int _ru;

        /**
         * @param matrixBlock  main input
         * @param sideInputArr prepared side inputs
         * @param matrixBlock2 shared output block
         * @param dArr         prepared scalar inputs
         * @param i            number of columns of the main input
         * @param i2           second dimension passed to the thread-local memory setup
         * @param i3           first row (inclusive) processed by this task
         * @param i4           last row (exclusive) processed by this task
         */
        protected ParExecTask(MatrixBlock matrixBlock, SpoofOperator.SideInput[] sideInputArr, MatrixBlock matrixBlock2, double[] dArr, int i, int i2, int i3, int i4) {
            this._a = matrixBlock;
            this._b = sideInputArr;
            this._c = matrixBlock2;
            this._scalars = dArr;
            this._clen = i;
            this._clen2 = i2;
            this._rl = i3;
            this._ru = i4;
        }

        /**
         * Executes the operator over this task's row range.
         *
         * @return non-zero count of output rows [_rl, _ru - 1] over all output columns
         */
        @Override // java.util.concurrent.Callable
        public Long call() {
            final boolean useVectMem = SpoofRowwise.this._reqVectMem > 0;
            if (useVectMem) {
                LibSpoofPrimitives.setupThreadLocalMemory(SpoofRowwise.this._reqVectMem, this._clen, this._clen2);
            }
            try {
                if (this._a.isInSparseFormat()) {
                    SpoofRowwise.this.executeSparse(this._a.getSparseBlock(), this._b, this._scalars, this._c.getDenseBlock(), this._clen, this._rl, this._ru, 0L);
                } else {
                    SpoofRowwise.this.executeDense(this._a.getDenseBlock(), this._b, this._scalars, this._c.getDenseBlock(), this._clen, this._rl, this._ru, 0L);
                }
            } finally {
                // Always release the thread-local buffers, even if the operator throws,
                // so that a reused worker thread does not keep them around.
                if (useVectMem) {
                    LibSpoofPrimitives.cleanupThreadLocalMemory();
                }
            }
            return Long.valueOf(this._c.recomputeNonZeros(this._rl, this._ru - 1, 0, this._c.getNumColumns() - 1));
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/codegen/SpoofRowwise$RowType.class */
    /**
     * Row-wise operator types. Every constant carries an integer code that can be
     * mapped back to the constant via {@link #valueOf(int)}.
     */
    public enum RowType {
        NO_AGG(0),
        NO_AGG_B1(1),
        NO_AGG_CONST(2),
        FULL_AGG(3),
        ROW_AGG(4),
        COL_AGG(5),
        COL_AGG_T(6),
        COL_AGG_B1(7),
        COL_AGG_B1_T(8),
        COL_AGG_B1R(9),
        COL_AGG_CONST(10);

        /** Lookup table from integer code to constant, filled once at class initialization. */
        private static final HashMap<Integer, RowType> BY_CODE = new HashMap<>();

        static {
            final RowType[] all = values();
            for (int k = 0; k < all.length; k++) {
                BY_CODE.put(all[k].value, all[k]);
            }
        }

        /** Integer code of this constant. */
        private final int value;

        RowType(int i) {
            value = i;
        }

        /**
         * Resolves a constant from its integer code.
         *
         * @param i integer code
         * @return the matching constant, or {@code null} if no constant has this code
         */
        public static RowType valueOf(int i) {
            return BY_CODE.get(i);
        }

        /** @return the integer code of this constant */
        public int getValue() {
            return value;
        }

        /** @return {@code true} for all COL_AGG* constants */
        public boolean isColumnAgg() {
            switch (this) {
                case COL_AGG:
                case COL_AGG_T:
                case COL_AGG_B1:
                case COL_AGG_B1_T:
                case COL_AGG_B1R:
                case COL_AGG_CONST:
                    return true;
                default:
                    return false;
            }
        }

        /** @return {@code true} for NO_AGG_B1, COL_AGG_B1, COL_AGG_B1_T and COL_AGG_B1R */
        public boolean isRowTypeB1() {
            switch (this) {
                case NO_AGG_B1:
                case COL_AGG_B1:
                case COL_AGG_B1_T:
                case COL_AGG_B1R:
                    return true;
                default:
                    return false;
            }
        }

        /** @return {@code true} for COL_AGG_B1 and COL_AGG_B1_T only */
        public boolean isRowTypeB1ColumnAgg() {
            switch (this) {
                case COL_AGG_B1:
                case COL_AGG_B1_T:
                    return true;
                default:
                    return false;
            }
        }

        /**
         * Tells whether the given constant second dimension applies to this type.
         *
         * @param j candidate constant second dimension
         * @return {@code true} for the *_CONST types regardless of {@code j}, and for
         *         B1 types when {@code j} is non-negative
         */
        public boolean isConstDim2(long j) {
            if (this == NO_AGG_CONST || this == COL_AGG_CONST) {
                return true;
            }
            return j >= 0 && isRowTypeB1();
        }
    }

    /**
     * Creates a row-wise operator.
     *
     * @param rowType row type of the operator
     * @param j       constant second dimension
     * @param z       flag stored as {@code _tB1}
     * @param i       value stored as {@code _reqVectMem}
     */
    public SpoofRowwise(RowType rowType, long j, boolean z, int i) {
        _type = rowType;
        _constDim2 = j;
        _tB1 = z;
        _reqVectMem = i;
    }

    /** @return the row type this operator was created with */
    public RowType getRowType() {
        return _type;
    }

    /** @return the constant second dimension this operator was created with */
    public long getConstDim2() {
        return _constDim2;
    }

    /** @return the value of {@code _reqVectMem} */
    public int getNumIntermediates() {
        return _reqVectMem;
    }

    /**
     * Builds the type label of this operator: the prefix "RA" followed by the second
     * dot-separated segment of the runtime class name.
     */
    @Override // org.apache.sysds.runtime.codegen.SpoofOperator
    public String getSpoofType() {
        final String[] nameSegments = getClass().getName().split("\\.");
        return "RA" + nameSegments[1];
    }

    /**
     * Creates the CUDA counterpart of this operator, carrying over its row type,
     * constant second dimension, {@code _tB1} flag and {@code _reqVectMem} value.
     *
     * @param num            integer forwarded (unboxed) to the CUDA operator
     * @param precisionProxy precision proxy forwarded to the CUDA operator
     */
    @Override // org.apache.sysds.runtime.codegen.SpoofOperator
    public SpoofCUDAOperator createCUDAInstrcution(Integer num, SpoofCUDAOperator.PrecisionProxy precisionProxy) {
        final SpoofCUDARowwise cudaOp = new SpoofCUDARowwise(_type, _constDim2, _tB1, _reqVectMem, num.intValue(), precisionProxy);
        return cudaOp;
    }

    /**
     * Executes the operator into a fresh 1x1 block and returns its single cell as a scalar.
     *
     * @param arrayList  input matrices (main input first)
     * @param arrayList2 scalar inputs
     * @param i          degree of parallelism; values above 1 use the multi-threaded path
     * @return the value at cell (0, 0) of the result
     */
    @Override // org.apache.sysds.runtime.codegen.SpoofOperator
    public ScalarObject execute(ArrayList<MatrixBlock> arrayList, ArrayList<ScalarObject> arrayList2, int i) {
        final MatrixBlock result;
        if (i > 1) {
            result = execute(arrayList, arrayList2, new MatrixBlock(1, 1, false), i);
        } else {
            result = execute(arrayList, arrayList2, new MatrixBlock(1, 1, false));
        }
        return new DoubleObject(result.quickGetValue(0, 0));
    }

    /**
     * Single-threaded execution with default settings: thread-local memory is managed by
     * the call itself, the output is freshly allocated, and the row offset is zero.
     *
     * @param arrayList   input matrices (main input first)
     * @param arrayList2  scalar inputs
     * @param matrixBlock output block
     * @return the output block
     */
    @Override // org.apache.sysds.runtime.codegen.SpoofOperator
    public MatrixBlock execute(ArrayList<MatrixBlock> arrayList, ArrayList<ScalarObject> arrayList2, MatrixBlock matrixBlock) {
        final boolean manageVectMem = true;
        final boolean keepAllocatedOutput = false;
        final long rowOffset = 0L;
        return execute(arrayList, arrayList2, matrixBlock, manageVectMem, keepAllocatedOutput, rowOffset);
    }

    /**
     * Single-threaded execution of the operator over all rows of the main input.
     *
     * @param arrayList   input matrices; element 0 is the main input, the rest are side inputs
     * @param arrayList2  scalar inputs
     * @param matrixBlock output block
     * @param z           when {@code true} (and {@code _reqVectMem > 0}) this call sets up and
     *                    cleans up the LibSpoofPrimitives thread-local memory itself
     * @param z2          when {@code true}, an already allocated output is reused without reset
     *                    and the final non-zero recomputation and sparsity check are skipped
     *                    (presumably for incremental aggregation across calls - TODO confirm)
     * @param j           offset added to each local row index to form the row index passed to genexec
     * @return the output block; a new transposed block when the output had to be flipped
     * @throws RuntimeException if the inputs are null/empty or the output is null
     */
    public MatrixBlock execute(ArrayList<MatrixBlock> arrayList, ArrayList<ScalarObject> arrayList2, MatrixBlock matrixBlock, boolean z, boolean z2, long j) {
        if (arrayList == null || arrayList.size() < 1 || matrixBlock == null) {
            throw new RuntimeException("Invalid input arguments.");
        }
        final int numRows = arrayList.get(0).getNumRows();
        final int numColumns = arrayList.get(0).getNumColumns();
        final int minColsMatrixSideInputs = this._type.isConstDim2(this._constDim2) ? (int) this._constDim2 : (this._type.isRowTypeB1() || hasMatrixSideInput(arrayList)) ? getMinColsMatrixSideInputs(arrayList) : -1;
        if (!z2 || !matrixBlock.isAllocated()) {
            allocateOutputMatrix(numRows, numColumns, minColsMatrixSideInputs, matrixBlock);
        }
        final DenseBlock outBlock = matrixBlock.getDenseBlock();
        final boolean flipOut = this._type.isRowTypeB1ColumnAgg() && LibSpoofPrimitives.isFlipOuter(matrixBlock.getNumRows(), matrixBlock.getNumColumns());
        final SpoofOperator.SideInput[] sideInputs = prepInputMatrices(arrayList, 1, arrayList.size() - 1, false, this._tB1);
        final double[] scalars = prepInputScalars(arrayList2);
        final boolean manageVectMem = z && this._reqVectMem > 0;
        if (manageVectMem) {
            LibSpoofPrimitives.setupThreadLocalMemory(this._reqVectMem, numColumns, minColsMatrixSideInputs);
        }
        try {
            MatrixBlock mainInput = arrayList.get(0);
            if (mainInput instanceof CompressedMatrixBlock) {
                mainInput = CompressedMatrixBlock.getUncompressed(mainInput);
            }
            if (mainInput.isInSparseFormat()) {
                executeSparse(mainInput.getSparseBlock(), sideInputs, scalars, outBlock, numColumns, 0, numRows, j);
            } else {
                executeDense(mainInput.getDenseBlock(), sideInputs, scalars, outBlock, numColumns, 0, numRows, j);
            }
        } finally {
            // Release the thread-local buffers even when the generated operator throws.
            if (manageVectMem) {
                LibSpoofPrimitives.cleanupThreadLocalMemory();
            }
        }
        if (flipOut) {
            fixTransposeDimensions(matrixBlock);
            matrixBlock = LibMatrixReorg.transpose(matrixBlock, new MatrixBlock(matrixBlock.getNumColumns(), matrixBlock.getNumRows(), false));
        }
        if (!z2) {
            matrixBlock.recomputeNonZeros();
            matrixBlock.examSparsity();
        }
        return matrixBlock;
    }

    /**
     * Multi-threaded execution of the operator. Falls back to the single-threaded path when
     * {@code i <= 1}, when a column aggregation does not satisfy
     * {@code LibMatrixMult.satisfiesMultiThreadingConstraints}, or when
     * {@code getTotalInputSize} reports fewer than 1,048,576.
     *
     * <p>For column and full aggregations every task produces a partial output that is added
     * into the result; for all other types the tasks write disjoint row ranges of the shared
     * output directly.
     *
     * @param arrayList   input matrices; element 0 is the main input, the rest are side inputs
     * @param arrayList2  scalar inputs
     * @param matrixBlock output block
     * @param i           degree of parallelism
     * @return the output block; a new transposed block when the output had to be flipped
     * @throws RuntimeException    if the inputs are null/empty or the output is null
     * @throws DMLRuntimeException if any task fails or the calling thread is interrupted
     */
    @Override // org.apache.sysds.runtime.codegen.SpoofOperator
    public MatrixBlock execute(ArrayList<MatrixBlock> arrayList, ArrayList<ScalarObject> arrayList2, MatrixBlock matrixBlock, int i) {
        // Validate before the redirect check below, which already dereferences the inputs.
        if (arrayList == null || arrayList.isEmpty() || matrixBlock == null) {
            throw new RuntimeException("Invalid input arguments.");
        }
        if (i <= 1 || (this._type.isColumnAgg() && !LibMatrixMult.satisfiesMultiThreadingConstraints(arrayList.get(0), i)) || getTotalInputSize(arrayList) < 1048576) {
            return execute(arrayList, arrayList2, matrixBlock);
        }
        final int numRows = arrayList.get(0).getNumRows();
        final int numColumns = arrayList.get(0).getNumColumns();
        final int minColsMatrixSideInputs = this._type.isConstDim2(this._constDim2) ? (int) this._constDim2 : (this._type.isRowTypeB1() || hasMatrixSideInput(arrayList)) ? getMinColsMatrixSideInputs(arrayList) : -1;
        allocateOutputMatrix(numRows, numColumns, minColsMatrixSideInputs, matrixBlock);
        final boolean flipOut = this._type.isRowTypeB1ColumnAgg() && LibSpoofPrimitives.isFlipOuter(matrixBlock.getNumRows(), matrixBlock.getNumColumns());
        MatrixBlock mainInput = arrayList.get(0);
        // Same handling as the single-threaded path: the tasks access the dense/sparse block directly.
        if (mainInput instanceof CompressedMatrixBlock) {
            mainInput = CompressedMatrixBlock.getUncompressed(mainInput);
        }
        final SpoofOperator.SideInput[] sideInputs = prepInputMatrices(arrayList, 1, arrayList.size() - 1, false, this._tB1);
        final double[] scalars = prepInputScalars(arrayList2);
        final ArrayList<Integer> blockSizes = UtilFunctions.getBalancedBlockSizesDefault(numRows, i, ((long) numRows) * ((long) numColumns) < 16777216);
        final ExecutorService pool = CommonThreadPool.get(i);
        try {
            if (this._type.isColumnAgg() || this._type == RowType.FULL_AGG) {
                final List<ParColAggTask> tasks = new ArrayList<>();
                final int partialLen = matrixBlock.getNumRows() * matrixBlock.getNumColumns();
                int rowStart = 0;
                for (int b = 0; b < blockSizes.size(); b++) {
                    final int rowEnd = rowStart + blockSizes.get(b).intValue();
                    tasks.add(new ParColAggTask(mainInput, sideInputs, scalars, numColumns, minColsMatrixSideInputs, partialLen, rowStart, rowEnd));
                    rowStart = rowEnd;
                }
                final List<Future<DenseBlock>> partials = pool.invokeAll(tasks);
                final int aggLen = this._type.isColumnAgg() ? matrixBlock.getNumRows() * matrixBlock.getNumColumns() : 1;
                // Fold every partial result into the shared output.
                for (Future<DenseBlock> partial : partials) {
                    LibMatrixMult.vectAdd(partial.get().valuesAt(0), matrixBlock.getDenseBlockValues(), 0, 0, aggLen);
                }
                matrixBlock.recomputeNonZeros();
            } else {
                final List<ParExecTask> tasks = new ArrayList<>();
                int rowStart = 0;
                for (int b = 0; b < blockSizes.size(); b++) {
                    final int rowEnd = rowStart + blockSizes.get(b).intValue();
                    tasks.add(new ParExecTask(mainInput, sideInputs, matrixBlock, scalars, numColumns, minColsMatrixSideInputs, rowStart, rowEnd));
                    rowStart = rowEnd;
                }
                long nnz = 0;
                for (Future<Long> rangeNnz : pool.invokeAll(tasks)) {
                    nnz += rangeNnz.get().longValue();
                }
                matrixBlock.setNonZeros(nnz);
            }
            if (flipOut) {
                fixTransposeDimensions(matrixBlock);
                matrixBlock = LibMatrixReorg.transpose(matrixBlock, new MatrixBlock(matrixBlock.getNumColumns(), matrixBlock.getNumRows(), false));
            }
            matrixBlock.examSparsity();
            return matrixBlock;
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        } finally {
            // Shut the pool down on every path, including task failures.
            pool.shutdown();
        }
    }

    /**
     * Checks whether any side input (every element after index 0) has more than one column.
     *
     * @param arrayList input matrices; element 0 is skipped
     * @return {@code true} if at least one side input has more than one column
     */
    public static boolean hasMatrixSideInput(ArrayList<MatrixBlock> arrayList) {
        final int count = arrayList.size();
        for (int pos = 1; pos < count; pos++) {
            MatrixBlock side = arrayList.get(pos);
            if (side.getNumColumns() > 1) {
                return true;
            }
        }
        return false;
    }

    /**
     * Determines the smallest column count among the side inputs (every element after
     * index 0) that have more than one column.
     *
     * @param arrayList input matrices; element 0 is skipped
     * @return the minimum such column count, or 1 if no side input has more than one column
     */
    protected static int getMinColsMatrixSideInputs(ArrayList<MatrixBlock> arrayList) {
        final int count = arrayList.size();
        boolean found = false;
        int smallest = 0;
        for (int pos = 1; pos < count; pos++) {
            final int cols = arrayList.get(pos).getNumColumns();
            if (cols > 1 && (!found || cols < smallest)) {
                smallest = cols;
                found = true;
            }
        }
        return found ? smallest : 1;
    }

    /**
     * Checks whether any side input (every element after index 0) has more than one column.
     *
     * @param arrayList input matrix objects; element 0 is skipped
     * @return {@code true} if at least one side input has more than one column
     */
    public static boolean hasMatrixObjectSideInput(ArrayList<MatrixObject> arrayList) {
        final int count = arrayList.size();
        for (int pos = 1; pos < count; pos++) {
            MatrixObject side = arrayList.get(pos);
            if (side.getNumColumns() > 1) {
                return true;
            }
        }
        return false;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Determines the smallest column count (narrowed to {@code int}) among the side inputs
     * (every element after index 0) that have more than one column.
     *
     * @param arrayList input matrix objects; element 0 is skipped
     * @return the minimum such column count, or 1 if no side input has more than one column
     */
    public static int getMinColsMatrixObjectSideInputs(ArrayList<MatrixObject> arrayList) {
        final int count = arrayList.size();
        boolean found = false;
        int smallest = 0;
        for (int pos = 1; pos < count; pos++) {
            final int cols = (int) arrayList.get(pos).getNumColumns();
            if (cols > 1 && (!found || cols < smallest)) {
                smallest = cols;
                found = true;
            }
        }
        return found ? smallest : 1;
    }

    /**
     * Resets the output block to the dimensions implied by this operator's row type
     * and allocates its dense block.
     *
     * @param i           number of rows of the main input
     * @param i2          number of columns of the main input
     * @param i3          second dimension forwarded to {@link OutputDimensions}
     * @param matrixBlock output block to reset and allocate
     */
    private void allocateOutputMatrix(int i, int i2, int i3, MatrixBlock matrixBlock) {
        final OutputDimensions shape = new OutputDimensions(i, i2, i3);
        final int outRows = shape.rows;
        final int outCols = shape.cols;
        matrixBlock.reset(outRows, outCols, false);
        matrixBlock.allocateDenseBlock();
    }

    /**
     * Swaps the row and column counts of the given block in place and sets its
     * non-zero count to the full cell count (rows * columns).
     *
     * @param matrixBlock block whose dimension metadata is swapped
     */
    private static void fixTransposeDimensions(MatrixBlock matrixBlock) {
        final int oldRows = matrixBlock.getNumRows();
        matrixBlock.setNumRows(matrixBlock.getNumColumns());
        matrixBlock.setNumColumns(oldRows);
        // Multiply in long: the int product overflows once rows * columns exceeds Integer.MAX_VALUE.
        matrixBlock.setNonZeros((long) matrixBlock.getNumRows() * matrixBlock.getNumColumns());
    }

    /**
     * Invokes the dense genexec variant once per row in [i2, i3). A {@code null} dense input
     * is delegated to {@link #executeSparse} with a {@code null} sparse block.
     *
     * @param denseBlock   dense main input, may be {@code null}
     * @param sideInputArr prepared side inputs
     * @param dArr         prepared scalar inputs
     * @param denseBlock2  output block
     * @param i            number of columns of the main input
     * @param i2           first row (inclusive)
     * @param i3           last row (exclusive)
     * @param j            offset added to the local row index for the row index passed to genexec
     */
    private void executeDense(DenseBlock denseBlock, SpoofOperator.SideInput[] sideInputArr, double[] dArr, DenseBlock denseBlock2, int i, int i2, int i3, long j) {
        if (denseBlock != null) {
            final SpoofOperator.SideInput[] sides = createSparseSideInputs(sideInputArr, true);
            int row = i2;
            while (row < i3) {
                final long offsetRow = j + row;
                genexec(denseBlock.values(row), denseBlock.pos(row), sides, dArr, denseBlock2.values(row), denseBlock2.pos(row), i, offsetRow, row);
                row++;
            }
        } else {
            executeSparse(null, sideInputArr, dArr, denseBlock2, i, i2, i3, j);
        }
    }

    /**
     * Invokes the sparse genexec variant once per row in [i2, i3). Rows that are empty, and
     * all rows when the sparse block itself is {@code null}, are passed as an empty sparse
     * row with zero entries.
     *
     * @param sparseBlock  sparse main input, may be {@code null}
     * @param sideInputArr prepared side inputs
     * @param dArr         prepared scalar inputs
     * @param denseBlock   output block
     * @param i            number of columns of the main input
     * @param i2           first row (inclusive)
     * @param i3           last row (exclusive)
     * @param j            offset added to the local row index for the row index passed to genexec
     */
    private void executeSparse(SparseBlock sparseBlock, SpoofOperator.SideInput[] sideInputArr, double[] dArr, DenseBlock denseBlock, int i, int i2, int i3, long j) {
        final SpoofOperator.SideInput[] sides = createSparseSideInputs(sideInputArr, true);
        // Shared placeholder used for every row without entries.
        final SparseRowVector emptyRow = new SparseRowVector(1);
        for (int row = i2; row < i3; row++) {
            final boolean hasEntries = sparseBlock != null && !sparseBlock.isEmpty(row);
            if (hasEntries) {
                genexec(sparseBlock.values(row), sparseBlock.indexes(row), sparseBlock.pos(row), sides, dArr, denseBlock.values(row), denseBlock.pos(row), sparseBlock.size(row), i, j + row, row);
            } else {
                genexec(emptyRow.values(), emptyRow.indexes(), 0, sides, dArr, denseBlock.values(row), denseBlock.pos(row), 0, i, j + row, row);
            }
        }
    }

    /**
     * Dense convenience overload: forwards to the abstract dense genexec, using the
     * given row index {@code i4} for both of its trailing row-index arguments.
     */
    protected final void genexec(double[] dArr, int i, SpoofOperator.SideInput[] sideInputArr, double[] dArr2, double[] dArr3, int i2, int i3, int i4) {
        final long widenedRow = i4;
        genexec(dArr, i, sideInputArr, dArr2, dArr3, i2, i3, widenedRow, i4);
    }

    /**
     * Sparse convenience overload: forwards to the abstract sparse genexec, using the
     * given row index {@code i5} for both of its trailing row-index arguments.
     */
    protected final void genexec(double[] dArr, int[] iArr, int i, SpoofOperator.SideInput[] sideInputArr, double[] dArr2, double[] dArr3, int i2, int i3, int i4, int i5) {
        final long widenedRow = i5;
        genexec(dArr, iArr, i, sideInputArr, dArr2, dArr3, i2, i3, i4, widenedRow, i5);
    }

    /**
     * Per-row hook for dense input rows, implemented by subclasses. As called from
     * executeDense in this class, the arguments are:
     *
     * @param dArr         value array holding the input row
     * @param i            position of the input row within {@code dArr}
     * @param sideInputArr side inputs
     * @param dArr2        scalar inputs
     * @param dArr3        value array of the output
     * @param i2           output position for this row
     * @param i3           number of columns of the main input
     * @param j            row index including the caller-supplied offset
     * @param i4           local row index
     */
    protected abstract void genexec(double[] dArr, int i, SpoofOperator.SideInput[] sideInputArr, double[] dArr2, double[] dArr3, int i2, int i3, long j, int i4);

    /**
     * Per-row hook for sparse input rows, implemented by subclasses. As called from
     * executeSparse in this class, the arguments are:
     *
     * @param dArr         values of the sparse input row
     * @param iArr         column indexes of the sparse input row
     * @param i            position of the row within {@code dArr}/{@code iArr}
     * @param sideInputArr side inputs
     * @param dArr2        scalar inputs
     * @param dArr3        value array of the output
     * @param i2           output position for this row
     * @param i3           number of entries in the sparse row (0 for empty rows)
     * @param i4           number of columns of the main input
     * @param j            row index including the caller-supplied offset
     * @param i5           local row index
     */
    protected abstract void genexec(double[] dArr, int[] iArr, int i, SpoofOperator.SideInput[] sideInputArr, double[] dArr2, double[] dArr3, int i2, int i3, int i4, long j, int i5);
}
