package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.class */
public class SpoofFEDInstruction extends FEDInstruction {
    private final SpoofOperator _op;
    private final CPOperand[] _inputs;
    private final CPOperand _output;

    /**
     * Federated execution strategy for the cellwise codegen template ({@link SpoofCellwise}).
     *
     * <p>The output stays federated for {@code NO_AGG}, for row aggregates over row-partitioned
     * inputs, and for column aggregates over column-partitioned inputs. All other cases fetch the
     * worker partials and aggregate them locally with a basic aggregate-unary operator.
     */
    private static class SpoofFEDCellwise extends SpoofFEDType {
        private final SpoofCellwise _op;
        private final SpoofCellwise.CellType _cellType;

        SpoofFEDCellwise(SpoofOperator spoofOperator, CPOperand cPOperand, FederationMap.FType fType) {
            super(cPOperand, fType);
            this._op = (SpoofCellwise) spoofOperator;
            this._cellType = this._op.getCellType();
        }

        /**
         * @return true if the worker results can be kept as a federated output: no aggregation,
         *         or an aggregation direction matching the partitioning direction
         */
        @Override
        protected boolean isFedOutput() {
            boolean rowAggOnRowPartition = this._cellType == SpoofCellwise.CellType.ROW_AGG
                && this._fedType == FederationMap.FType.ROW;
            boolean colAggOnColPartition = this._cellType == SpoofCellwise.CellType.COL_AGG
                && this._fedType == FederationMap.FType.COL;
            return rowAggOnRowPartition || colAggOnColPartition
                || this._cellType == SpoofCellwise.CellType.NO_AGG;
        }

        /**
         * Registers a copy of {@code federationMap} (re-keyed to {@code j}) as the output mapping,
         * with the aggregated dimension (if any) collapsed to size 1.
         */
        @Override
        protected void setFedOutput(ExecutionContext executionContext, FederationMap federationMap, long j) {
            executionContext.getMatrixObject(this._output).setFedMapping(modifyFedRanges(federationMap.copyWithNewID(j)));
        }

        /**
         * For row/column aggregates, sets the aggregated dimension of every federated range to
         * [0, 1): dimension 0 for {@code COL_AGG}, dimension 1 for {@code ROW_AGG}. Other cell types
         * leave the ranges unchanged. Mutates and returns the given map.
         */
        private FederationMap modifyFedRanges(FederationMap federationMap) {
            if (this._cellType == SpoofCellwise.CellType.ROW_AGG || this._cellType == SpoofCellwise.CellType.COL_AGG) {
                int aggDim = this._cellType == SpoofCellwise.CellType.COL_AGG ? 0 : 1;
                IntStream.range(0, federationMap.getFederatedRanges().length).forEach(r -> {
                    federationMap.getFederatedRanges()[r].setBeginDim(aggDim, 0L);
                    federationMap.getFederatedRanges()[r].setEndDim(aggDim, 1L);
                });
            }
            return federationMap;
        }

        /**
         * Aggregates the worker partials locally. The opcode is {@code "ua"} + direction
         * ({@code ""} full, {@code "r"} row, {@code "c"} column) + function ({@code "k+"},
         * {@code "min"}, {@code "max"}). A full aggregate yields a scalar, otherwise a matrix.
         *
         * @throws DMLRuntimeException for {@code NO_AGG} or an unsupported aggregation function
         */
        @Override
        protected void aggResult(ExecutionContext executionContext, Future<FederatedResponse>[] futureArr, FederationMap federationMap) {
            SpoofCellwise.AggOp aggOp = this._op.getAggOp();
            String aggInst = "ua";
            switch (this._cellType) {
                case FULL_AGG:
                    break;
                case ROW_AGG:
                    // plain literal; previously a decompiler-inlined GPUInstruction timer constant
                    aggInst += "r";
                    break;
                case COL_AGG:
                    aggInst += "c";
                    break;
                case NO_AGG:
                default:
                    throw new DMLRuntimeException("Aggregation type not supported yet.");
            }
            switch (aggOp) {
                case SUM:
                case SUM_SQ:
                    // partial sums of squares are combined by plain addition
                    aggInst += "k+";
                    break;
                case MIN:
                    aggInst += "min";
                    break;
                case MAX:
                    aggInst += "max";
                    break;
                default:
                    throw new DMLRuntimeException("Aggregation operation not supported yet.");
            }
            AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
            if (this._cellType == SpoofCellwise.CellType.FULL_AGG) {
                executionContext.setVariable(this._output.getName(), FederationUtils.aggScalar(aop, futureArr));
            } else {
                executionContext.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(aop, futureArr, federationMap));
            }
        }
    }

    /**
     * Federated execution strategy for the multi-aggregate codegen template
     * ({@link SpoofMultiAggregate}). The result is never federated: the worker partials are
     * always merged into a single local matrix.
     */
    private static class SpoofFEDMultiAgg extends SpoofFEDType {
        private final SpoofMultiAggregate _op;

        SpoofFEDMultiAgg(SpoofOperator op, CPOperand out, FederationMap.FType fedType) {
            super(out, fedType);
            _op = (SpoofMultiAggregate) op;
        }

        /** Multi-aggregates always produce a local (non-federated) output. */
        @Override
        protected boolean isFedOutput() {
            return false;
        }

        /** Not supported: a multi-aggregate output is never federated. */
        @Override
        protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long outputId) {
            throw new DMLRuntimeException("SpoofFEDMultiAgg cannot create a federated output.");
        }

        /**
         * Folds every further worker partial into the first one using the operator's aggregation
         * functions, then publishes the first partial as the output matrix.
         */
        @Override
        protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] responses, FederationMap fedMap) {
            MatrixBlock[] partials = FederationUtils.getResults(responses);
            SpoofCellwise.AggOp[] ops = _op.getAggOps();
            MatrixBlock acc = partials[0];
            // presumably aggregatePartialResults updates its second argument in place
            for (int k = 1; k < partials.length; k++) {
                SpoofMultiAggregate.aggregatePartialResults(ops, acc, partials[k]);
            }
            ec.setMatrixOutput(_output.getName(), acc);
        }
    }

    /**
     * Federated execution strategy for the outer-product codegen template
     * ({@link SpoofOuterProduct}).
     *
     * <p>The output stays federated for cellwise outer products, for left outer products over
     * column-partitioned inputs, and for right outer products over row-partitioned inputs.
     * Otherwise the worker partials are summed locally.
     */
    private static class SpoofFEDOuterProduct extends SpoofFEDType {
        private final SpoofOuterProduct _op;
        private final SpoofOuterProduct.OutProdType _outProdType;
        /** Input operands of the instruction; index 0 is the main input, read in setFedOutput. */
        private final CPOperand[] _inputs;

        SpoofFEDOuterProduct(SpoofOperator spoofOperator, CPOperand cPOperand, FederationMap.FType fType, CPOperand[] cPOperandArr) {
            super(cPOperand, fType);
            this._op = (SpoofOuterProduct) spoofOperator;
            this._outProdType = this._op.getOuterProdType();
            this._inputs = cPOperandArr;
        }

        /**
         * Slices {@code matrixObject} across the workers. The flag passed to
         * {@link FederationMap#broadcastSliced} is true for column partitioning (presumably
         * "transposed" slicing — confirm against FederationMap).
         */
        @Override
        protected FederatedRequest[] broadcastSliced(MatrixObject matrixObject, FederationMap federationMap) {
            return federationMap.broadcastSliced(matrixObject, this._fedType == FederationMap.FType.COL);
        }

        /**
         * Decides whether a non-federated input is sliced instead of fully broadcast. This is the
         * case when it has the full federated shape, or when:
         * <ul>
         *   <li>ROW: its row count equals the federated row count and it is not input 2;</li>
         *   <li>COL: its row count equals the federated column count and it is not input 1.</li>
         * </ul>
         *
         * @throws DMLRuntimeException for federation types other than ROW or COL
         */
        @Override
        protected boolean needsBroadcastSliced(FederationMap federationMap, long rowNum, long colNum, int inputIndex) {
            boolean fullShape = rowNum == federationMap.getMaxIndexInRange(0)
                && colNum == federationMap.getMaxIndexInRange(1);
            if (this._fedType == FederationMap.FType.ROW) {
                return fullShape || (rowNum == federationMap.getMaxIndexInRange(0) && inputIndex != 2);
            }
            if (this._fedType == FederationMap.FType.COL) {
                return fullShape || (rowNum == federationMap.getMaxIndexInRange(1) && inputIndex != 1);
            }
            throw new DMLRuntimeException("Only row partitioned or column partitioned federated input supported yet.");
        }

        @Override
        protected boolean isFedOutput() {
            return (this._outProdType == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT
                    && this._fedType == FederationMap.FType.COL)
                || (this._outProdType == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT
                    && this._fedType == FederationMap.FType.ROW)
                || this._outProdType == SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT;
        }

        /**
         * Registers a copy of {@code federationMap} (re-keyed to {@code j}) as the output mapping.
         * For a left outer product the copy is transposed. The non-partitioned dimension of each
         * range is then set to [0, outDim), where the output dimensions are:
         * LEFT = (cols of input 0, cols of input 1), RIGHT = (rows of input 0, cols of input 2),
         * CELLWISE = shape of input 0.
         *
         * @throws DMLRuntimeException for other outer-product types
         */
        @Override
        protected void setFedOutput(ExecutionContext executionContext, FederationMap federationMap, long j) {
            FederationMap outMap = federationMap.copyWithNewID(j);
            long[] outDims = new long[2];
            MatrixObject mainInput = executionContext.getMatrixObject(this._inputs[0]);
            switch (this._outProdType) {
                case LEFT_OUTER_PRODUCT:
                    outMap = outMap.transpose();
                    outDims[0] = mainInput.getNumColumns();
                    outDims[1] = executionContext.getMatrixObject(this._inputs[1]).getNumColumns();
                    break;
                case RIGHT_OUTER_PRODUCT:
                    outDims[0] = mainInput.getNumRows();
                    outDims[1] = executionContext.getMatrixObject(this._inputs[2]).getNumColumns();
                    break;
                case CELLWISE_OUTER_PRODUCT:
                    outDims[0] = mainInput.getNumRows();
                    outDims[1] = mainInput.getNumColumns();
                    break;
                default:
                    throw new DMLRuntimeException("Outer Product Type " + this._outProdType + " not supported yet.");
            }
            MatrixObject out = executionContext.getMatrixObject(this._output);
            // ROW partitioning fixes the column dimension (1) to the full output width, else the rows (0)
            int dim = outMap.getType() == FederationMap.FType.ROW ? 1 : 0;
            out.setFedMapping(modifyFedRanges(outMap, dim, outDims[dim]));
        }

        /** Sets dimension {@code dim} of every federated range to [0, {@code size}); mutates and returns the map. */
        private static FederationMap modifyFedRanges(FederationMap federationMap, int dim, long size) {
            IntStream.range(0, federationMap.getFederatedRanges().length).forEach(r -> {
                federationMap.getFederatedRanges()[r].setBeginDim(dim, 0L);
                federationMap.getFederatedRanges()[r].setEndDim(dim, size);
            });
            return federationMap;
        }

        /**
         * Sums the worker partials with {@code uak+}. Left/right outer products yield a matrix and
         * the aggregating outer product yields a scalar.
         *
         * @throws DMLRuntimeException for cellwise or other outer-product types
         */
        @Override
        protected void aggResult(ExecutionContext executionContext, Future<FederatedResponse>[] futureArr, FederationMap federationMap) {
            AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
            switch (this._outProdType) {
                case LEFT_OUTER_PRODUCT:
                case RIGHT_OUTER_PRODUCT:
                    executionContext.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(aop, futureArr, federationMap));
                    break;
                case AGG_OUTER_PRODUCT:
                    executionContext.setVariable(this._output.getName(), FederationUtils.aggScalar(aop, futureArr));
                    break;
                case CELLWISE_OUTER_PRODUCT:
                default:
                    throw new DMLRuntimeException("Outer Product Type " + this._outProdType + " not supported yet.");
            }
        }
    }

    /**
     * Federated execution strategy for the rowwise codegen template ({@link SpoofRowwise}).
     *
     * <p>Non-aggregating row types over row-partitioned inputs keep a federated output. Full, row
     * and column aggregates are summed locally, and only row partitioning is supported for them.
     */
    private static class SpoofFEDRowwise extends SpoofFEDType {
        private final SpoofRowwise _op;
        private final SpoofRowwise.RowType _rowType;

        SpoofFEDRowwise(SpoofOperator op, CPOperand out, FederationMap.FType fedType) {
            super(out, fedType);
            _op = (SpoofRowwise) op;
            _rowType = _op.getRowType();
        }

        /** True only for NO_AGG, NO_AGG_B1 or NO_AGG_CONST over a row-partitioned input. */
        @Override
        protected boolean isFedOutput() {
            if (_fedType != FederationMap.FType.ROW) {
                return false;
            }
            return _rowType == SpoofRowwise.RowType.NO_AGG
                || _rowType == SpoofRowwise.RowType.NO_AGG_B1
                || _rowType == SpoofRowwise.RowType.NO_AGG_CONST;
        }

        /**
         * Registers a copy of {@code fedMap} (re-keyed to {@code outputId}) as the output mapping,
         * with the column range of every partition widened to the output's column count.
         */
        @Override
        protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long outputId) {
            MatrixObject out = ec.getMatrixObject(_output);
            FederationMap outMap = modifyFedRanges(fedMap.copyWithNewID(outputId), out.getNumColumns());
            out.setFedMapping(outMap);
        }

        /** Sets the column dimension of every federated range to [0, numCols); mutates and returns the map. */
        private static FederationMap modifyFedRanges(FederationMap fedMap, long numCols) {
            int numRanges = fedMap.getFederatedRanges().length;
            for (int r = 0; r < numRanges; r++) {
                fedMap.getFederatedRanges()[r].setBeginDim(1, 0L);
                fedMap.getFederatedRanges()[r].setEndDim(1, numCols);
            }
            return fedMap;
        }

        /**
         * Sums the worker partials: {@code uak+} for FULL_AGG (scalar result), {@code uark+} for
         * ROW_AGG and {@code uack+} for column aggregates (matrix results).
         *
         * @throws DMLRuntimeException for non-row partitioning or unsupported row types
         */
        @Override
        protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] responses, FederationMap fedMap) {
            if (_fedType != FederationMap.FType.ROW) {
                throw new DMLRuntimeException("Only row partitioned federated matrices supported yet.");
            }
            boolean fullAgg = _rowType == SpoofRowwise.RowType.FULL_AGG;
            String aggInst;
            if (fullAgg) {
                aggInst = "uak+";
            } else if (_rowType == SpoofRowwise.RowType.ROW_AGG) {
                aggInst = "uark+";
            } else if (_rowType.isColumnAgg()) {
                aggInst = "uack+";
            } else {
                throw new DMLRuntimeException("AggregationType not supported yet.");
            }
            AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
            String outName = _output.getName();
            if (fullAgg) {
                ec.setVariable(outName, FederationUtils.aggScalar(aop, responses));
            } else {
                ec.setMatrixOutput(outName, FederationUtils.aggMatrix(aop, responses, fedMap));
            }
        }
    }

    /**
     * Base strategy for executing one codegen template over federated data.
     *
     * <p>It decides how non-federated inputs are shipped (sliced vs. full broadcast) and whether
     * the result stays federated or is aggregated locally.
     *
     * <p>NOTE(review): the decompiler reports this type was originally {@code private}; it is kept
     * {@code public} here so that any existing external references keep compiling.
     */
    public static abstract class SpoofFEDType {
        CPOperand _output;
        FederationMap.FType _fedType;

        protected SpoofFEDType(CPOperand output, FederationMap.FType fedType) {
            _output = output;
            _fedType = fedType;
        }

        /** Default slicing of a non-federated input: row-wise (flag {@code false}). */
        protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap) {
            return fedMap.broadcastSliced(mo, false);
        }

        /**
         * Decides whether a non-federated input is sliced instead of fully broadcast. This is the
         * case when it has the full federated shape, or when:
         * <ul>
         *   <li>ROW: its row count matches the federated rows and either it or the federated data
         *       has a single column;</li>
         *   <li>COL: its column count matches the federated columns and either it or the federated
         *       data has a single row.</li>
         * </ul>
         *
         * @param inputIndex position of the input among the shipped operands (unused here)
         * @throws DMLRuntimeException for federation types other than ROW or COL
         */
        protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
            long fedRows = fedMap.getMaxIndexInRange(0);
            long fedCols = fedMap.getMaxIndexInRange(1);
            boolean fullShape = rowNum == fedRows && colNum == fedCols;
            boolean partitionMatch;
            if (_fedType == FederationMap.FType.ROW) {
                partitionMatch = rowNum == fedRows && (colNum == 1 || fedCols == 1);
            } else if (_fedType == FederationMap.FType.COL) {
                partitionMatch = colNum == fedCols && (rowNum == 1 || fedRows == 1);
            } else {
                throw new DMLRuntimeException("Only row partitioned or column partitioned federated input supported yet.");
            }
            return fullShape || partitionMatch;
        }

        /**
         * Publishes the instruction result: a federated mapping when {@link #isFedOutput()},
         * otherwise a locally aggregated value.
         */
        protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] responses, FederationMap fedMap, long outputId) {
            if (!isFedOutput()) {
                aggResult(ec, responses, fedMap);
                return;
            }
            setFedOutput(ec, fedMap, outputId);
        }

        /** @return whether the worker-side results form the (federated) output directly */
        protected abstract boolean isFedOutput();

        /** Sets a federated output mapping that references the worker variables with id {@code outputId}. */
        protected abstract void setFedOutput(ExecutionContext ec, FederationMap fedMap, long outputId);

        /** Combines the worker responses into a local output. */
        protected abstract void aggResult(ExecutionContext ec, Future<FederatedResponse>[] responses, FederationMap fedMap);
    }

    /**
     * Creates a federated fused-operator instruction (see {@link #parseInstruction(String)}).
     *
     * @param op     generated codegen operator
     * @param in     input operands
     * @param out    output operand
     * @param opcode opcode (instruction prefix plus spoof type)
     * @param istr   full instruction string
     */
    private SpoofFEDInstruction(SpoofOperator op, CPOperand[] in, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.SpoofFused, opcode, istr);
        _inputs = in;
        _output = out;
        _op = op;
    }

    /**
     * Parses a federated codegen instruction.
     *
     * <p>Part layout (n = number of parts):
     * <ul>
     *   <li>0: opcode prefix;</li>
     *   <li>2: name of the generated operator class;</li>
     *   <li>3 .. n-3: inputs;</li>
     *   <li>n-2: output.</li>
     * </ul>
     * Parts 1 and n-1 are not read here.
     *
     * @param str full instruction string
     * @return the parsed instruction
     */
    public static SpoofFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        final int firstInput = 3;
        final int outputPos = parts.length - 2;
        CPOperand[] inputs = new CPOperand[outputPos - firstInput];
        SpoofOperator op = CodegenUtils.createInstance(CodegenUtils.getClass(parts[2]));
        String opcode = parts[0] + op.getSpoofType();
        for (int k = firstInput; k < outputPos; k++) {
            inputs[k - firstInput] = new CPOperand(parts[k]);
        }
        CPOperand output = new CPOperand(parts[outputPos]);
        return new SpoofFEDInstruction(op, inputs, output, opcode, str);
    }

    /**
     * Executes the fused operator on the federated workers.
     *
     * <p>The federation map of the first input that is federated (other than as BROADCAST)
     * drives execution. The direct superclass of the generated operator selects the
     * template-specific strategy for shipping inputs and assembling the output.
     *
     * @throws DMLRuntimeException if no input is such a federated matrix, or the operator's
     *         template is not supported
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        FederationMap federationMap = findFederationMap(executionContext);
        if (federationMap == null) {
            throw new DMLRuntimeException("Federated code generation requires at least one federated matrix input.");
        }
        FederationMap.FType fedType = federationMap.getType();

        // the generated operator class is expected to extend its template class directly
        Class<?> template = this._op.getClass().getSuperclass();
        SpoofFEDType spoofFEDType;
        if (template == SpoofCellwise.class) {
            spoofFEDType = new SpoofFEDCellwise(this._op, this._output, fedType);
        } else if (template == SpoofRowwise.class) {
            spoofFEDType = new SpoofFEDRowwise(this._op, this._output, fedType);
        } else if (template == SpoofMultiAggregate.class) {
            spoofFEDType = new SpoofFEDMultiAgg(this._op, this._output, fedType);
        } else if (template == SpoofOuterProduct.class) {
            spoofFEDType = new SpoofFEDOuterProduct(this._op, this._output, fedType, this._inputs);
        } else {
            throw new DMLRuntimeException("Federated code generation only supported for cellwise, rowwise, multiaggregate, and outerproduct templates.");
        }
        processRequest(executionContext, federationMap, spoofFEDType);
    }

    /**
     * Returns the federation map of the first input that is a matrix federated other than as
     * BROADCAST, or {@code null} if there is none.
     */
    private FederationMap findFederationMap(ExecutionContext executionContext) {
        for (CPOperand input : this._inputs) {
            Data data = executionContext.getVariable(input);
            if (data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FederationMap.FType.BROADCAST)) {
                return ((MatrixObject) data).getFedMapping();
            }
        }
        return null;
    }

    /**
     * Ships all inputs to the workers, runs the fused instruction there and publishes the output.
     *
     * <p>Inputs are handled as follows:
     * <ul>
     *   <li>Federated matrices are referenced by their existing id.</li>
     *   <li>Other matrices are either sliced across the workers or fully broadcast, as decided by
     *       {@code spoofFEDType}.</li>
     *   <li>Scalars are broadcast.</li>
     * </ul>
     * For a non-federated output, the result variable is fetched and then cleaned up on the
     * workers.
     *
     * @param executionContext current execution context
     * @param federationMap    federation map driving execution
     * @param spoofFEDType     template-specific strategy
     */
    private void processRequest(ExecutionContext executionContext, FederationMap federationMap, SpoofFEDType spoofFEDType) {
        ArrayList<FederatedRequest> frBroadcast = new ArrayList<>();
        ArrayList<FederatedRequest[]> frBroadcastSliced = new ArrayList<>();
        long[] frIds = new long[this._inputs.length];
        int index = 0;
        for (CPOperand input : this._inputs) {
            Data data = executionContext.getVariable(input);
            if (data instanceof MatrixObject) {
                MatrixObject mo = (MatrixObject) data;
                if (mo.isFederatedExcept(FederationMap.FType.BROADCAST)) {
                    frIds[index++] = mo.getFedMapping().getID();
                } else if (spoofFEDType.needsBroadcastSliced(federationMap, mo.getNumRows(), mo.getNumColumns(), index)) {
                    FederatedRequest[] slices = spoofFEDType.broadcastSliced(mo, federationMap);
                    frIds[index++] = slices[0].getID();
                    frBroadcastSliced.add(slices);
                } else {
                    FederatedRequest fr = federationMap.broadcast(mo);
                    frIds[index++] = fr.getID();
                    frBroadcast.add(fr);
                }
            } else if (data instanceof ScalarObject) {
                FederatedRequest fr = federationMap.broadcast((ScalarObject) data);
                frIds[index++] = fr.getID();
                frBroadcast.add(fr);
            }
        }

        // presumably clears the literal flags of operands that are now broadcast variables.
        // NOTE(review): this replaces *every* "true" in the instruction string (including inside
        // operand names or boolean literal values), and permanently mutates instString — confirm.
        this.instString = this.instString.replace("true", "false");
        FederatedRequest frCompute = FederationUtils.callInstruction(this.instString, this._output, this._inputs, frIds);

        ArrayList<FederatedRequest> frAll = new ArrayList<>(frBroadcast);
        frAll.add(frCompute);
        if (!spoofFEDType.isFedOutput()) {
            // local output: fetch the result variable, then remove it on the workers
            frAll.add(new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCompute.getID()));
            frAll.add(federationMap.cleanup(getTID(), frCompute.getID()));
        }

        // the slice groups must become a FederatedRequest[][]; a FederatedRequest[0] template
        // would throw ArrayStoreException/ClassCastException at runtime
        Future<FederatedResponse>[] responses = federationMap.executeMultipleSlices(getTID(), true,
            frBroadcastSliced.toArray(new FederatedRequest[0][]), frAll.toArray(new FederatedRequest[0]));
        spoofFEDType.setOutput(executionContext, responses, federationMap, frCompute.getID());
    }

    /**
     * Checks whether the inputs contain at least one matrix federated other than as BROADCAST,
     * and whether all such matrices are mutually aligned. Delegates with a {@code null}
     * federation type (presumably meaning "any type").
     */
    public static boolean isFederated(ExecutionContext executionContext, CPOperand[] cPOperandArr, Class<?> cls) {
        final FederationMap.FType anyType = null;
        return isFederated(executionContext, anyType, cPOperandArr, cls);
    }

    /**
     * Checks whether the inputs contain at least one matrix federated as {@code fType} (and not
     * as BROADCAST), and whether all such matrices are aligned with the first one.
     *
     * <p>Alignment is checked on the first matrix's partitioning direction (ROW or COL). For the
     * outer-product template, transposed alignment (ROW_T, COL_T) is also accepted.
     *
     * @param executionContext execution context holding the inputs
     * @param fType            required federation type, passed through to
     *                         {@code MatrixObject.isFederated}
     * @param cPOperandArr     input operands
     * @param cls              codegen template class
     * @return true iff a matching federated matrix exists and all of them are aligned
     */
    public static boolean isFederated(ExecutionContext executionContext, FederationMap.FType fType, CPOperand[] cPOperandArr, Class<?> cls) {
        FederationMap federationMap = null;
        boolean found = false;
        ArrayList<FederationMap.AlignType> alignTypes = new ArrayList<>();
        for (CPOperand input : cPOperandArr) {
            Data data = executionContext.getVariable(input);
            if (!(data instanceof MatrixObject)) {
                continue;
            }
            MatrixObject mo = (MatrixObject) data;
            if (!mo.isFederated(fType) || mo.isFederated(FederationMap.FType.BROADCAST)) {
                continue;
            }
            if (federationMap == null) {
                // first federated matrix defines the reference map and alignment constraints
                federationMap = mo.getFedMapping();
                found = true;
                alignTypes.add(mo.isFederated(FederationMap.FType.ROW) ? FederationMap.AlignType.ROW : FederationMap.AlignType.COL);
                if (cls == SpoofOuterProduct.class) {
                    Collections.addAll(alignTypes, FederationMap.AlignType.ROW_T, FederationMap.AlignType.COL_T);
                }
            } else if (!federationMap.isAligned(mo.getFedMapping(), alignTypes.toArray(new FederationMap.AlignType[0]))) {
                // one misaligned federated input makes the whole set unsupported
                return false;
            }
        }
        return found;
    }
}
