package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.class */
/**
 * Federated reorg instruction supporting transpose ({@code r'}), reverse ({@code rev}) and diagonal
 * ({@code rdiag}) on row- or column-federated matrices.
 */
public class ReorgFEDInstruction extends UnaryFEDInstruction {

    /**
     * Worker-side UDF for {@code rdiag} on a column-vector input: expands this worker's vector slice into
     * its share of the diagonal matrix and registers the result under {@code _outputID}.
     */
    public static class DiagMatrix extends FederatedUDF {
        private static final long serialVersionUID = -3466926635958851402L;
        private final long _outputID;
        private final ReorgOperator _r_op;
        // number of rows of the whole (federated) input vector
        private final int _len;
        // [begin, end) of this worker's partition along the federated dimension
        private final int[] _slice;
        private final boolean _rowFed;

        private DiagMatrix(long inputID, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed, int len) {
            super(new long[] {inputID});
            _outputID = outputID;
            _r_op = r_op;
            _len = len;
            _rowFed = rowFed;
            _slice = slice;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            MatrixBlock in = ((MatrixObject) data[0]).acquireReadAndRelease();
            MatrixBlock localDiag = in.reorgOperations(_r_op, (MatrixValue) new MatrixBlock(), 0, 0, 0);
            MatrixBlock res;
            if (_rowFed) {
                // local rows x full length; the local block is placed at columns [slice[0], slice[1])
                res = new MatrixBlock(in.getNumRows(), _len, DataExpression.DEFAULT_DELIM_FILL_VALUE);
                res.copy(0, res.getNumRows() - 1, _slice[0], _slice[1] - 1, localDiag, false);
            } else {
                // NOTE(review): target is _len x _slice[1] while the copy region is rows
                // [slice[0], slice[1]) x all local columns; confirm this matches localDiag's shape
                // for column-federated inputs.
                res = new MatrixBlock(_len, _slice[1], DataExpression.DEFAULT_DELIM_FILL_VALUE);
                res.copy(_slice[0], _slice[1] - 1, 0, in.getNumColumns() - 1, localDiag, false);
            }
            MatrixObject out = ExecutionContext.createMatrixObject(res);
            out.setDiag(true);
            ec.setVariable(String.valueOf(_outputID), out);
            // the coordinator uses these dimensions to rebuild the output federated ranges
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
                new int[] {res.getNumRows(), res.getNumColumns()});
        }

        @Override
        public List<Long> getOutputIds() {
            return new ArrayList<>(Arrays.asList(Long.valueOf(_outputID)));
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return buildUdfLineage(ec, getClass().getSimpleName(), getInputIDs(), _outputID,
                scalarLiteral(_r_op.fn.getClass().getSimpleName(), Types.ValueType.STRING),
                scalarLiteral(String.valueOf(_len), Types.ValueType.INT32),
                scalarLiteral(Arrays.toString(_slice), Types.ValueType.STRING),
                scalarLiteral(String.valueOf(_rowFed), Types.ValueType.BOOLEAN));
        }
    }

    /**
     * Worker-side UDF for {@code rdiag} on a matrix input: restricts the local block to this partition's
     * [begin, end) range along the non-federated dimension as well, takes its diagonal, and registers
     * the result under {@code _outputID}. NOTE(review): presumably assumes a square input matrix.
     */
    public static class Rdiag extends FederatedUDF {
        private static final long serialVersionUID = -3466926635958851402L;
        private final long _outputID;
        private final ReorgOperator _r_op;
        // [begin, end) of this worker's partition along the federated dimension
        private final int[] _slice;
        private final boolean _rowFed;

        private Rdiag(long inputID, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed) {
            super(new long[] {inputID});
            _outputID = outputID;
            _r_op = r_op;
            _slice = slice;
            _rowFed = rowFed;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            MatrixBlock in = ((MatrixObject) data[0]).acquireReadAndRelease();
            // row-federated: keep local rows, cut columns to the partition range;
            // column-federated: cut rows to the partition range
            MatrixBlock diag = (_rowFed
                ? in.slice(0, in.getNumRows() - 1, _slice[0], _slice[1] - 1, new MatrixBlock())
                : in.slice2(_slice[0], _slice[1] - 1))
                .reorgOperations(_r_op, new MatrixBlock(), 0, 0, 0);
            MatrixObject out = ExecutionContext.createMatrixObject(diag);
            out.setDiag(true);
            ec.setVariable(String.valueOf(_outputID), out);
            // the coordinator uses these dimensions to rebuild the output federated ranges
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
                new int[] {diag.getNumRows(), diag.getNumColumns()});
        }

        @Override
        public List<Long> getOutputIds() {
            return new ArrayList<>(Arrays.asList(Long.valueOf(_outputID)));
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return buildUdfLineage(ec, getClass().getSimpleName(), getInputIDs(), _outputID,
                scalarLiteral(_r_op.fn.getClass().getSimpleName(), Types.ValueType.STRING),
                scalarLiteral(Arrays.toString(_slice), Types.ValueType.STRING),
                scalarLiteral(String.valueOf(_rowFed), Types.ValueType.BOOLEAN));
        }
    }

    /** Federation map produced by an rdiag UDF plus the per-range result dimensions reported by workers. */
    private static final class RdiagResult {
        private final FederationMap fedMap;
        private final Map<FederatedRange, int[]> dcs;

        RdiagResult(FederationMap fedMap, Map<FederatedRange, int[]> dcs) {
            this.fedMap = fedMap;
            this.dcs = dcs;
        }

        FederationMap getFedMap() {
            return fedMap;
        }

        Map<FederatedRange, int[]> getDcs() {
            return dcs;
        }
    }

    /** Builds a partition-local UDF from its input ID, output ID, slice and federation orientation. */
    private interface RdiagUdfFactory {
        FederatedUDF create(long inputID, long outputID, int[] slice, boolean rowFed);
    }

    public ReorgFEDInstruction(Operator operator, CPOperand in, CPOperand out, String opcode, String istr,
        FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.Reorg, operator, in, out, opcode, istr, fedOut);
    }

    public ReorgFEDInstruction(Operator operator, CPOperand in, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.Reorg, operator, in, out, opcode, istr);
    }

    public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction inst) {
        return new ReorgFEDInstruction(inst.getOperator(), inst.input1, inst.output, inst.getOpcode(),
            inst.getInstructionString(), FEDInstruction.FederatedOutput.NONE);
    }

    public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction inst) {
        return new ReorgFEDInstruction(inst.getOperator(), inst.input1, inst.output, inst.getOpcode(),
            inst.getInstructionString(), FEDInstruction.FederatedOutput.NONE);
    }

    /**
     * Parses a federated reorg instruction string.
     *
     * @param str instruction string ({@code r'}, {@code rdiag} or {@code rev})
     * @return the parsed instruction
     * @throws DMLRuntimeException for an unsupported opcode
     */
    public static ReorgFEDInstruction parseInstruction(String str) {
        CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("r'")) {
            InstructionUtils.checkNumFields(str, 2, 3, 4);
            in.split(parts[1]);
            out.split(parts[2]);
            // CP layout: r' in out k [fedOut]; SPARK layout: r' in out [fedOut]
            boolean isSpark = str.startsWith(Types.ExecMode.SPARK.name());
            int k = isSpark ? 0 : Integer.parseInt(parts[3]);
            int fedOutPos = isSpark ? 3 : 4;
            // the federated-output flag is optional; default instead of indexing past the end
            FEDInstruction.FederatedOutput fedOut = fedOutPos < parts.length
                ? FEDInstruction.FederatedOutput.valueOf(parts[fedOutPos])
                : FEDInstruction.FederatedOutput.NONE;
            return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k),
                in, out, opcode, str, fedOut);
        }
        if (opcode.equalsIgnoreCase("rdiag")) {
            parseUnaryInstruction(str, in, out);
            return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()),
                in, out, opcode, str, parseFedOutFlag(str, 3));
        }
        if (opcode.equalsIgnoreCase("rev")) {
            parseUnaryInstruction(str, in, out);
            return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()),
                in, out, opcode, str, parseFedOutFlag(str, 3));
        }
        throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode);
    }

    /**
     * Executes the reorg on a row- or column-federated input and sets the output variable.
     *
     * @throws DMLRuntimeException if the input is not row/column federated, is PART-federated
     *         (rev/rdiag only), or the opcode is unsupported
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(input1);
        ReorgOperator r_op = (ReorgOperator) _optr;
        if (!mo1.isFederated()) {
            throw new DMLRuntimeException("Federated Reorg: Federated input expected, but invoked w/ " + mo1.isFederated());
        }
        if (!mo1.isFederated(FTypes.FType.COL) && !mo1.isFederated(FTypes.FType.ROW)) {
            throw new DMLRuntimeException("Federation type " + mo1.getFedMapping().getType() + " is not supported for Reorg processing");
        }
        // match opcodes case-insensitively, consistent with parseInstruction
        if (instOpcode.equalsIgnoreCase("r'")) {
            processTranspose(ec, mo1);
            return;
        }
        if (mo1.isFederated(FTypes.FType.PART)) {
            throw new DMLRuntimeException("Operation with opcode " + instOpcode + " is not supported with PART input");
        }
        if (instOpcode.equalsIgnoreCase("rev")) {
            processRev(ec, mo1);
        } else if (instOpcode.equalsIgnoreCase("rdiag")) {
            processRdiag(ec, mo1, r_op);
        } else {
            throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + instOpcode);
        }
    }

    /**
     * Sends a PUT_VAR placeholder for a fresh output ID followed by this instruction to every worker.
     *
     * @return the instruction request (whose ID names the worker-side output) and the worker responses
     */
    private Pair<FederatedRequest, Future<FederatedResponse>[]> executeAtWorkers(MatrixObject mo1) {
        boolean isSpark = instString.startsWith("SPARK");
        long id = FederationUtils.getNextFedDataID();
        FederatedRequest putOutput = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id,
            new MatrixCharacteristics(-1L, -1L), mo1.getDataType());
        FederatedRequest exec = FederationUtils.callInstruction(instString, output, id,
            new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()},
            isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true);
        Future<FederatedResponse>[] responses = mo1.getFedMapping().execute(getTID(), true, putOutput, exec);
        return Pair.of(exec, responses);
    }

    /** Transpose: either pulls and binds the worker results locally, or keeps a transposed federation map. */
    private void processTranspose(ExecutionContext ec, MatrixObject mo1) {
        Pair<FederatedRequest, Future<FederatedResponse>[]> sent = executeAtWorkers(mo1);
        FederatedRequest exec = sent.getLeft();
        if (_fedOut == null || _fedOut.isForcedLocal()) {
            // NOTE(review): the instruction is re-sent together with GET_VAR, so workers compute the
            // transpose twice; kept because ordering between the two request batches is not visible here.
            FederatedRequest getOutput = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, exec.getID());
            Future<FederatedResponse>[] results = mo1.getFedMapping().execute(getTID(), true, exec, getOutput);
            // second argument presumably selects cbind: transposed row partitions lie side by side
            ec.setMatrixOutput(output.getName(), FederationUtils.bind(results, mo1.isFederated(FTypes.FType.ROW)));
            return;
        }
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics()
            .setDimension(mo1.getNumColumns(), mo1.getNumRows())
            .setBlocksize(mo1.getBlocksize())
            .setNonZeros(mo1.getNnz() != -1 ? mo1.getNnz() : FederationUtils.sumNonZeros(sent.getRight()));
        out.setFedMapping(mo1.getFedMapping().copyWithNewID(exec.getID()).transpose());
    }

    /** Reverse: executes rev at the workers and derives the output federation map from the input's. */
    private void processRev(ExecutionContext ec, MatrixObject mo1) {
        Pair<FederatedRequest, Future<FederatedResponse>[]> sent = executeAtWorkers(mo1);
        if (mo1.isFederated(FTypes.FType.ROW)) {
            // NOTE(review): any return value of reverseFedMap() is discarded and the call targets the
            // INPUT's mapping; confirm it reverses in place and that mutating the input map is intended.
            mo1.getFedMapping().reverseFedMap();
        }
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics()
            .setDimension(mo1.getNumRows(), mo1.getNumColumns())
            .setBlocksize(mo1.getBlocksize())
            .setNonZeros(mo1.getNnz() != -1 ? mo1.getNnz() : FederationUtils.sumNonZeros(sent.getRight()));
        out.setFedMapping(mo1.getFedMapping().copyWithNewID(sent.getLeft().getID()));
        optionalForceLocal(out);
    }

    /** Diagonal: vector-to-matrix for column vectors, otherwise matrix-to-vector. */
    private void processRdiag(ExecutionContext ec, MatrixObject mo1, ReorgOperator r_op) {
        boolean columnVector = mo1.getNumColumns() == 1 && mo1.getNumRows() != 1;
        FederationMap diagMap = updateFedRanges(columnVector ? rdiagV2M(mo1, r_op) : rdiagM2V(mo1, r_op));
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(diagMap.getMaxIndexInRange(0), diagMap.getMaxIndexInRange(1),
            mo1.getBlocksize());
        out.setFedMapping(diagMap);
        optionalForceLocal(out);
    }

    /**
     * Rewrites the ranges of the rdiag output map from the per-range result dimensions reported by the
     * workers: a range starts at 0 if its original begin was 0 (or it is the first range), otherwise
     * where the previous range now ends; its end is begin plus the reported size.
     * NOTE(review): for column-federated inputs every original row begin is presumably 0, so all
     * output ranges start at row 0 while columns are chained; confirm the intended output shape.
     *
     * @throws DMLRuntimeException if a worker reported no dimensions for a range
     */
    private FederationMap updateFedRanges(RdiagResult result) {
        FederationMap fedMap = result.getFedMap();
        Map<FederatedRange, int[]> sizes = result.getDcs();
        FederatedRange[] ranges = fedMap.getFederatedRanges();
        // the ranges are hash keys of the size map and are mutated below, so resolve all sizes first
        int[][] dims = new int[ranges.length][];
        for (int i = 0; i < ranges.length; i++) {
            dims[i] = sizes.get(ranges[i]);
            if (dims[i] == null) {
                throw new DMLRuntimeException("Federated Reorg: missing result dimensions for range " + ranges[i]);
            }
        }
        for (int i = 0; i < ranges.length; i++) {
            for (int d = 0; d < 2; d++) {
                long begin = (ranges[i].getBeginDims()[d] == 0 || i == 0) ? 0L : ranges[i - 1].getEndDims()[d];
                ranges[i].setBeginDim(d, begin);
                ranges[i].setEndDim(d, begin + dims[i][d]);
            }
        }
        return fedMap;
    }

    /** If the output is forced local, reads it back and then cleans up its federated mapping at the workers. */
    private void optionalForceLocal(MatrixObject out) {
        if (_fedOut == null || !_fedOut.isForcedLocal()) {
            return;
        }
        out.acquireReadAndRelease();
        out.getFedMapping().cleanup(getTID(), out.getFedMapping().getID());
    }

    private RdiagResult rdiagV2M(MatrixObject mo1, ReorgOperator r_op) {
        return rdiagAtWorkers(mo1, (inputID, outputID, slice, rowFed) ->
            new DiagMatrix(inputID, outputID, r_op, slice, rowFed, (int) mo1.getNumRows()));
    }

    private RdiagResult rdiagM2V(MatrixObject mo1, ReorgOperator r_op) {
        return rdiagAtWorkers(mo1, (inputID, outputID, slice, rowFed) ->
            new Rdiag(inputID, outputID, r_op, slice, rowFed));
    }

    /**
     * Runs one rdiag UDF per federated partition in parallel and collects the result dimensions each
     * worker reports, keyed by the partition's range in the new federation map.
     */
    private RdiagResult rdiagAtWorkers(MatrixObject mo1, RdiagUdfFactory udfFactory) {
        FederationMap fedMapping = mo1.getFedMapping();
        boolean rowFed = mo1.isFederated(FTypes.FType.ROW);
        long outputID = FederationUtils.getNextFedDataID();
        Map<FederatedRange, int[]> sizes = new HashMap<>();
        FederationMap diagMap = fedMapping.mapParallel(outputID, (range, fedData) -> {
            try {
                // partition bounds along the federated dimension (rows if row-federated, else columns)
                int dim = rowFed ? 0 : 1;
                int[] slice = new int[] {range.getBeginDimsInt()[dim], range.getEndDimsInt()[dim]};
                FederatedUDF udf = udfFactory.create(fedData.getVarID(), outputID, slice, rowFed);
                FederatedRequest request = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L,
                    new Object[] {udf});
                FederatedResponse response = fedData.executeFederatedOperation(new FederatedRequest[] {request}).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
                int[] resultDims = (int[]) response.getData()[0];
                synchronized (sizes) {
                    sizes.put(range, resultDims);
                }
                return null;
            } catch (InterruptedException e) {
                // preserve the interrupt for the calling pool thread
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(e);
            } catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        });
        return new RdiagResult(diagMap, sizes);
    }

    /** Lineage of an rdiag UDF: the lineage of its inputs followed by its scalar parameters. */
    private static Pair<String, LineageItem> buildUdfLineage(ExecutionContext ec, String opcode, long[] inputIDs,
        long outputID, CPOperand... params) {
        LineageItem[] inputs = Arrays.stream(inputIDs)
            .mapToObj(id -> ec.getLineage().get(String.valueOf(id)))
            .toArray(LineageItem[]::new);
        LineageItem[] paramItems = LineageItemUtils.getLineage(ec, params);
        LineageItem[] all = Stream.concat(Arrays.stream(inputs), Arrays.stream(paramItems))
            .toArray(LineageItem[]::new);
        return Pair.of(String.valueOf(outputID), new LineageItem(opcode, all));
    }

    /** Scalar literal operand used as a lineage parameter. */
    private static CPOperand scalarLiteral(String value, Types.ValueType valueType) {
        return new CPOperand(value, valueType, Types.DataType.SCALAR, true);
    }
}
