package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Future;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.class */
public class SpoofFEDInstruction extends FEDInstruction {
    private final SpoofOperator _op;
    private final CPOperand[] _inputs;
    private final CPOperand _output;

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction$SpoofFEDCellwise.class */
    private static class SpoofFEDCellwise extends SpoofFEDType {
        private final SpoofCellwise _op;

        SpoofFEDCellwise(SpoofOperator spoofOperator, CPOperand cPOperand) {
            super(cPOperand);
            this._op = (SpoofCellwise) spoofOperator;
        }

        @Override // org.apache.sysds.runtime.instructions.fed.SpoofFEDInstruction.SpoofFEDType
        protected void setOutput(ExecutionContext executionContext, Future<FederatedResponse>[] futureArr, FederationMap federationMap) {
            AggregateUnaryOperator parseBasicAggregateUnaryOperator;
            AggregateUnaryOperator parseBasicAggregateUnaryOperator2;
            AggregateUnaryOperator parseBasicAggregateUnaryOperator3;
            FederationMap.FType type = federationMap.getType();
            SpoofCellwise.AggOp aggOp = this._op.getAggOp();
            SpoofCellwise.CellType cellType = this._op.getCellType();
            if (cellType == SpoofCellwise.CellType.FULL_AGG) {
                if (aggOp == SpoofCellwise.AggOp.SUM || aggOp == SpoofCellwise.AggOp.SUM_SQ) {
                    parseBasicAggregateUnaryOperator3 = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
                } else if (aggOp == SpoofCellwise.AggOp.MIN) {
                    parseBasicAggregateUnaryOperator3 = InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
                } else {
                    if (aggOp != SpoofCellwise.AggOp.MAX) {
                        throw new DMLRuntimeException("Aggregation operation not supported yet.");
                    }
                    parseBasicAggregateUnaryOperator3 = InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
                }
                executionContext.setVariable(this._output.getName(), FederationUtils.aggScalar(parseBasicAggregateUnaryOperator3, futureArr));
                return;
            }
            if (cellType == SpoofCellwise.CellType.ROW_AGG) {
                if (type == FederationMap.FType.ROW) {
                    executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, false));
                    return;
                }
                if (type != FederationMap.FType.COL) {
                    throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
                }
                if (aggOp == SpoofCellwise.AggOp.SUM || aggOp == SpoofCellwise.AggOp.SUM_SQ) {
                    parseBasicAggregateUnaryOperator2 = InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
                } else if (aggOp == SpoofCellwise.AggOp.MIN) {
                    parseBasicAggregateUnaryOperator2 = InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
                } else {
                    if (aggOp != SpoofCellwise.AggOp.MAX) {
                        throw new DMLRuntimeException("Aggregation operation not supported yet.");
                    }
                    parseBasicAggregateUnaryOperator2 = InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
                }
                executionContext.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(parseBasicAggregateUnaryOperator2, futureArr, federationMap));
                return;
            }
            if (cellType != SpoofCellwise.CellType.COL_AGG) {
                if (cellType != SpoofCellwise.CellType.NO_AGG) {
                    throw new DMLRuntimeException("Aggregation type not supported yet.");
                }
                if (type == FederationMap.FType.ROW) {
                    executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, false));
                    return;
                } else {
                    if (type != FederationMap.FType.COL) {
                        throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
                    }
                    executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, true));
                    return;
                }
            }
            if (type != FederationMap.FType.ROW) {
                if (type != FederationMap.FType.COL) {
                    throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
                }
                executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, true));
                return;
            }
            if (aggOp == SpoofCellwise.AggOp.SUM || aggOp == SpoofCellwise.AggOp.SUM_SQ) {
                parseBasicAggregateUnaryOperator = InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
            } else if (aggOp == SpoofCellwise.AggOp.MIN) {
                parseBasicAggregateUnaryOperator = InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
            } else {
                if (aggOp != SpoofCellwise.AggOp.MAX) {
                    throw new DMLRuntimeException("Aggregation operation not supported yet.");
                }
                parseBasicAggregateUnaryOperator = InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
            }
            executionContext.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(parseBasicAggregateUnaryOperator, futureArr, federationMap));
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction$SpoofFEDMultiAgg.class */
    private static class SpoofFEDMultiAgg extends SpoofFEDType {
        private final SpoofMultiAggregate _op;

        SpoofFEDMultiAgg(SpoofOperator spoofOperator, CPOperand cPOperand) {
            super(cPOperand);
            this._op = (SpoofMultiAggregate) spoofOperator;
        }

        @Override // org.apache.sysds.runtime.instructions.fed.SpoofFEDInstruction.SpoofFEDType
        protected void setOutput(ExecutionContext executionContext, Future<FederatedResponse>[] futureArr, FederationMap federationMap) {
            MatrixBlock[] results = FederationUtils.getResults(futureArr);
            SpoofCellwise.AggOp[] aggOps = this._op.getAggOps();
            for (int i = 1; i < results.length; i++) {
                SpoofMultiAggregate.aggregatePartialResults(aggOps, results[0], results[i]);
            }
            executionContext.setMatrixOutput(this._output.getName(), results[0]);
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction$SpoofFEDOuterProduct.class */
    private static class SpoofFEDOuterProduct extends SpoofFEDType {
        private final SpoofOuterProduct _op;

        SpoofFEDOuterProduct(SpoofOperator spoofOperator, CPOperand cPOperand) {
            super(cPOperand);
            this._op = (SpoofOuterProduct) spoofOperator;
        }

        @Override // org.apache.sysds.runtime.instructions.fed.SpoofFEDInstruction.SpoofFEDType
        protected FederatedRequest[] broadcastSliced(MatrixObject matrixObject, FederationMap federationMap) {
            return federationMap.broadcastSliced(matrixObject, federationMap.getType() == FederationMap.FType.COL);
        }

        @Override // org.apache.sysds.runtime.instructions.fed.SpoofFEDInstruction.SpoofFEDType
        protected boolean needsBroadcastSliced(FederationMap federationMap, long j, long j2, int i) {
            boolean z;
            FederationMap.FType type = federationMap.getType();
            boolean z2 = false | (j == federationMap.getMaxIndexInRange(0) && j2 == federationMap.getMaxIndexInRange(1));
            if (type == FederationMap.FType.ROW) {
                z = z2 | (j == federationMap.getMaxIndexInRange(0) && i != 2);
            } else {
                if (type != FederationMap.FType.COL) {
                    throw new DMLRuntimeException("Only row partitioned or column partitioned federated input supported yet.");
                }
                z = z2 | (j == federationMap.getMaxIndexInRange(1) && i != 1);
            }
            return z;
        }

        @Override // org.apache.sysds.runtime.instructions.fed.SpoofFEDInstruction.SpoofFEDType
        protected void setOutput(ExecutionContext executionContext, Future<FederatedResponse>[] futureArr, FederationMap federationMap) {
            FederationMap.FType type = federationMap.getType();
            SpoofOuterProduct.OutProdType outerProdType = this._op.getOuterProdType();
            if (outerProdType == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT) {
                if (type == FederationMap.FType.ROW) {
                    executionContext.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(InstructionUtils.parseBasicAggregateUnaryOperator("uak+"), futureArr, federationMap));
                    return;
                } else {
                    if (type != FederationMap.FType.COL) {
                        throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
                    }
                    executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, false));
                    return;
                }
            }
            if (outerProdType == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
                if (type == FederationMap.FType.ROW) {
                    executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, false));
                    return;
                } else {
                    if (type != FederationMap.FType.COL) {
                        throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
                    }
                    executionContext.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(InstructionUtils.parseBasicAggregateUnaryOperator("uak+"), futureArr, federationMap));
                    return;
                }
            }
            if (outerProdType != SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT) {
                if (outerProdType != SpoofOuterProduct.OutProdType.AGG_OUTER_PRODUCT) {
                    throw new DMLRuntimeException("Outer Product Type " + outerProdType + " not supported yet.");
                }
                executionContext.setVariable(this._output.getName(), FederationUtils.aggScalar(InstructionUtils.parseBasicAggregateUnaryOperator("uak+"), futureArr));
                return;
            }
            if (type == FederationMap.FType.ROW) {
                executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, false));
            } else {
                if (type != FederationMap.FType.COL) {
                    throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
                }
                executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, true));
            }
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction$SpoofFEDRowwise.class */
    private static class SpoofFEDRowwise extends SpoofFEDType {
        private final SpoofRowwise _op;

        SpoofFEDRowwise(SpoofOperator spoofOperator, CPOperand cPOperand) {
            super(cPOperand);
            this._op = (SpoofRowwise) spoofOperator;
        }

        @Override // org.apache.sysds.runtime.instructions.fed.SpoofFEDInstruction.SpoofFEDType
        protected void setOutput(ExecutionContext executionContext, Future<FederatedResponse>[] futureArr, FederationMap federationMap) {
            SpoofRowwise.RowType rowType = this._op.getRowType();
            if (rowType == SpoofRowwise.RowType.FULL_AGG) {
                executionContext.setVariable(this._output.getName(), FederationUtils.aggScalar(InstructionUtils.parseBasicAggregateUnaryOperator("uak+"), futureArr));
                return;
            }
            if (rowType == SpoofRowwise.RowType.ROW_AGG) {
                executionContext.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(InstructionUtils.parseBasicAggregateUnaryOperator("uark+"), futureArr, federationMap));
                return;
            }
            if (rowType == SpoofRowwise.RowType.COL_AGG || rowType == SpoofRowwise.RowType.COL_AGG_T || rowType == SpoofRowwise.RowType.COL_AGG_B1 || rowType == SpoofRowwise.RowType.COL_AGG_B1_T || rowType == SpoofRowwise.RowType.COL_AGG_B1R || rowType == SpoofRowwise.RowType.COL_AGG_CONST) {
                executionContext.setMatrixOutput(this._output.getName(), FederationUtils.aggMatrix(InstructionUtils.parseBasicAggregateUnaryOperator("uack+"), futureArr, federationMap));
            } else {
                if (rowType != SpoofRowwise.RowType.NO_AGG && rowType != SpoofRowwise.RowType.NO_AGG_B1 && rowType != SpoofRowwise.RowType.NO_AGG_CONST) {
                    throw new DMLRuntimeException("AggregationType not supported yet.");
                }
                if (federationMap.getType() != FederationMap.FType.ROW) {
                    throw new DMLRuntimeException("Only row partitioned federated matrices supported yet.");
                }
                executionContext.setMatrixOutput(this._output.getName(), FederationUtils.bind(futureArr, false));
            }
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction$SpoofFEDType.class */
    private static abstract class SpoofFEDType {
        CPOperand _output;

        protected SpoofFEDType(CPOperand cPOperand) {
            this._output = cPOperand;
        }

        protected FederatedRequest[] broadcastSliced(MatrixObject matrixObject, FederationMap federationMap) {
            return federationMap.broadcastSliced(matrixObject, false);
        }

        protected boolean needsBroadcastSliced(FederationMap federationMap, long j, long j2, int i) {
            boolean z;
            FederationMap.FType type = federationMap.getType();
            boolean z2 = j == federationMap.getMaxIndexInRange(0) && j2 == federationMap.getMaxIndexInRange(1);
            if (type == FederationMap.FType.ROW) {
                z = z2 | (j == federationMap.getMaxIndexInRange(0) && (j2 == 1 || j2 == ((long) federationMap.getSize()) || federationMap.getMaxIndexInRange(1) == 1));
            } else {
                if (type != FederationMap.FType.COL) {
                    throw new DMLRuntimeException("Only row partitioned or column partitioned federated input supported yet.");
                }
                z = z2 | (j2 == federationMap.getMaxIndexInRange(1) && (j == 1 || j == ((long) federationMap.getSize()) || federationMap.getMaxIndexInRange(0) == 1));
            }
            return z;
        }

        protected abstract void setOutput(ExecutionContext executionContext, Future<FederatedResponse>[] futureArr, FederationMap federationMap);
    }

    private SpoofFEDInstruction(SpoofOperator spoofOperator, CPOperand[] cPOperandArr, CPOperand cPOperand, String str, String str2) {
        super(FEDInstruction.FEDType.SpoofFused, str, str2);
        this._op = spoofOperator;
        this._inputs = cPOperandArr;
        this._output = cPOperand;
    }

    public static SpoofFEDInstruction parseInstruction(String str) {
        String[] instructionPartsWithValueType = InstructionUtils.getInstructionPartsWithValueType(str);
        CPOperand[] cPOperandArr = new CPOperand[(instructionPartsWithValueType.length - 3) - 2];
        SpoofOperator createInstance = CodegenUtils.createInstance(CodegenUtils.getClass(instructionPartsWithValueType[2]));
        String str2 = instructionPartsWithValueType[0] + createInstance.getSpoofType();
        for (int i = 3; i < instructionPartsWithValueType.length - 2; i++) {
            cPOperandArr[i - 3] = new CPOperand(instructionPartsWithValueType[i]);
        }
        return new SpoofFEDInstruction(createInstance, cPOperandArr, new CPOperand(instructionPartsWithValueType[instructionPartsWithValueType.length - 2]), str2, str);
    }

    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        SpoofFEDType spoofFEDOuterProduct;
        Class<? super Object> superclass = this._op.getClass().getSuperclass();
        if (superclass == SpoofCellwise.class) {
            spoofFEDOuterProduct = new SpoofFEDCellwise(this._op, this._output);
        } else if (superclass == SpoofRowwise.class) {
            spoofFEDOuterProduct = new SpoofFEDRowwise(this._op, this._output);
        } else if (superclass == SpoofMultiAggregate.class) {
            spoofFEDOuterProduct = new SpoofFEDMultiAgg(this._op, this._output);
        } else {
            if (superclass != SpoofOuterProduct.class) {
                throw new DMLRuntimeException("Federated code generation only supported for cellwise, rowwise, multiaggregate, and outerproduct templates.");
            }
            spoofFEDOuterProduct = new SpoofFEDOuterProduct(this._op, this._output);
        }
        FederationMap federationMap = null;
        CPOperand[] cPOperandArr = this._inputs;
        int length = cPOperandArr.length;
        int i = 0;
        while (true) {
            if (i >= length) {
                break;
            }
            Data variable = executionContext.getVariable(cPOperandArr[i]);
            if ((variable instanceof MatrixObject) && ((MatrixObject) variable).isFederated()) {
                federationMap = ((MatrixObject) variable).getFedMapping();
                break;
            }
            i++;
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        long[] jArr = new long[this._inputs.length];
        int i2 = 0;
        for (CPOperand cPOperand : this._inputs) {
            Data variable2 = executionContext.getVariable(cPOperand);
            if (variable2 instanceof MatrixObject) {
                MatrixObject matrixObject = (MatrixObject) variable2;
                if (matrixObject.isFederated()) {
                    int i3 = i2;
                    i2++;
                    jArr[i3] = matrixObject.getFedMapping().getID();
                } else if (spoofFEDOuterProduct.needsBroadcastSliced(federationMap, matrixObject.getNumRows(), matrixObject.getNumColumns(), i2)) {
                    FederatedRequest[] broadcastSliced = spoofFEDOuterProduct.broadcastSliced(matrixObject, federationMap);
                    int i4 = i2;
                    i2++;
                    jArr[i4] = broadcastSliced[0].getID();
                    arrayList2.add(broadcastSliced);
                } else {
                    FederatedRequest broadcast = federationMap.broadcast(matrixObject);
                    int i5 = i2;
                    i2++;
                    jArr[i5] = broadcast.getID();
                    arrayList.add(broadcast);
                }
            } else if (variable2 instanceof ScalarObject) {
                FederatedRequest broadcast2 = federationMap.broadcast((ScalarObject) variable2);
                int i6 = i2;
                i2++;
                jArr[i6] = broadcast2.getID();
                arrayList.add(broadcast2);
            }
        }
        this.instString = this.instString.replace("true", "false");
        FederatedRequest callInstruction = FederationUtils.callInstruction(this.instString, this._output, this._inputs, jArr);
        FederatedRequest federatedRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, callInstruction.getID());
        ArrayList arrayList3 = new ArrayList();
        arrayList3.add(federationMap.cleanup(getTID(), callInstruction.getID()));
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            arrayList3.add(federationMap.cleanup(getTID(), ((FederatedRequest) it.next()).getID()));
        }
        Iterator it2 = arrayList2.iterator();
        while (it2.hasNext()) {
            arrayList3.add(federationMap.cleanup(getTID(), ((FederatedRequest[]) it2.next())[0].getID()));
        }
        spoofFEDOuterProduct.setOutput(executionContext, federationMap.executeMultipleSlices(getTID(), true, (FederatedRequest[][]) arrayList2.toArray(new FederatedRequest[0]), (FederatedRequest[]) ArrayUtils.addAll(ArrayUtils.addAll(arrayList.toArray(new FederatedRequest[0]), new FederatedRequest[]{callInstruction, federatedRequest}), arrayList3.toArray(new FederatedRequest[0]))), federationMap);
    }
}
