package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.class */
public final class IndexingFEDInstruction extends UnaryFEDInstruction {
    protected final CPOperand rowLower;
    protected final CPOperand rowUpper;
    protected final CPOperand colLower;
    protected final CPOperand colUpper;

    protected IndexingFEDInstruction(CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, CPOperand cPOperand5, CPOperand cPOperand6, String str, String str2) {
        super(FEDInstruction.FEDType.MatrixIndexing, null, cPOperand, cPOperand6, str, str2);
        this.rowLower = cPOperand2;
        this.rowUpper = cPOperand3;
        this.colLower = cPOperand4;
        this.colUpper = cPOperand5;
    }

    protected IndexingFEDInstruction(CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, CPOperand cPOperand5, CPOperand cPOperand6, CPOperand cPOperand7, String str, String str2) {
        super(FEDInstruction.FEDType.MatrixIndexing, (Operator) null, cPOperand, cPOperand2, cPOperand7, str, str2);
        this.rowLower = cPOperand3;
        this.rowUpper = cPOperand4;
        this.colLower = cPOperand5;
        this.colUpper = cPOperand6;
    }

    protected IndexRange getIndexRange(ExecutionContext executionContext) {
        return new IndexRange((int) (executionContext.getScalarInput(this.rowLower).getLongValue() - 1), (int) (executionContext.getScalarInput(this.rowUpper).getLongValue() - 1), (int) (executionContext.getScalarInput(this.colLower).getLongValue() - 1), (int) (executionContext.getScalarInput(this.colUpper).getLongValue() - 1));
    }

    public static IndexingFEDInstruction parseInstruction(String str) {
        String[] instructionPartsWithValueType = InstructionUtils.getInstructionPartsWithValueType(str);
        String str2 = instructionPartsWithValueType[0];
        if (str2.equalsIgnoreCase(RightIndex.OPCODE)) {
            if (instructionPartsWithValueType.length != 7 && instructionPartsWithValueType.length != 8) {
                throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
            }
            CPOperand cPOperand = new CPOperand(instructionPartsWithValueType[1]);
            CPOperand cPOperand2 = new CPOperand(instructionPartsWithValueType[2]);
            CPOperand cPOperand3 = new CPOperand(instructionPartsWithValueType[3]);
            CPOperand cPOperand4 = new CPOperand(instructionPartsWithValueType[4]);
            CPOperand cPOperand5 = new CPOperand(instructionPartsWithValueType[5]);
            CPOperand cPOperand6 = new CPOperand(instructionPartsWithValueType[6]);
            if (cPOperand.getDataType() == Types.DataType.MATRIX || cPOperand.getDataType() == Types.DataType.FRAME) {
                return new IndexingFEDInstruction(cPOperand, cPOperand2, cPOperand3, cPOperand4, cPOperand5, cPOperand6, str2, str);
            }
            throw new DMLRuntimeException("Can index only on matrices, frames in federated.");
        }
        if (!str2.equalsIgnoreCase(LeftIndex.OPCODE) && !str2.equalsIgnoreCase("mapLeftIndex")) {
            throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingFEDInstruction: " + str);
        }
        if (instructionPartsWithValueType.length != 8 && instructionPartsWithValueType.length != 9) {
            throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
        }
        CPOperand cPOperand7 = new CPOperand(instructionPartsWithValueType[1]);
        CPOperand cPOperand8 = new CPOperand(instructionPartsWithValueType[2]);
        CPOperand cPOperand9 = new CPOperand(instructionPartsWithValueType[3]);
        CPOperand cPOperand10 = new CPOperand(instructionPartsWithValueType[4]);
        CPOperand cPOperand11 = new CPOperand(instructionPartsWithValueType[5]);
        CPOperand cPOperand12 = new CPOperand(instructionPartsWithValueType[6]);
        CPOperand cPOperand13 = new CPOperand(instructionPartsWithValueType[7]);
        if (cPOperand7.getDataType() == Types.DataType.MATRIX || cPOperand7.getDataType() == Types.DataType.FRAME || cPOperand8.getDataType() == Types.DataType.MATRIX || cPOperand8.getDataType() == Types.DataType.FRAME) {
            return new IndexingFEDInstruction(cPOperand7, cPOperand8, cPOperand9, cPOperand10, cPOperand11, cPOperand12, cPOperand13, str2, str);
        }
        throw new DMLRuntimeException("Can index only on matrices, frames, and lists.");
    }

    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        if (getOpcode().equalsIgnoreCase(RightIndex.OPCODE)) {
            rightIndexing(executionContext);
        } else {
            leftIndexing(executionContext);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v100, types: [java.util.List] */
    private void rightIndexing(ExecutionContext executionContext) {
        CacheableData<?> cacheableData = executionContext.getCacheableData(this.input1);
        IndexRange indexRange = getIndexRange(executionContext);
        FederationMap filter = cacheableData.getFedMapping().filter(indexRange);
        String[] strArr = new String[filter.getSize()];
        ArrayList arrayList = new ArrayList();
        int i = 0;
        Iterator<Pair<FederatedRange, FederatedData>> it = filter.getMap().iterator();
        while (it.hasNext()) {
            FederatedRange federatedRange = (FederatedRange) it.next().getKey();
            long j = federatedRange.getBeginDims()[0];
            long j2 = federatedRange.getEndDims()[0];
            long j3 = federatedRange.getBeginDims()[1];
            long j4 = federatedRange.getEndDims()[1];
            long j5 = indexRange.rowStart >= j ? indexRange.rowStart - j : 0L;
            long j6 = (indexRange.rowEnd < j || indexRange.rowEnd >= j2) ? (j2 - j) - 1 : indexRange.rowEnd - j;
            long j7 = indexRange.colStart >= j3 ? indexRange.colStart - j3 : 0L;
            long j8 = (indexRange.colEnd < j3 || indexRange.colEnd >= j4) ? (j4 - j3) - 1 : indexRange.colEnd - j3;
            federatedRange.setBeginDim(0, Math.max(j - indexRange.rowStart, 0L));
            federatedRange.setBeginDim(1, Math.max(j3 - indexRange.colStart, 0L));
            federatedRange.setEndDim(0, indexRange.rowEnd >= j2 ? j2 - indexRange.rowStart : (indexRange.rowEnd - indexRange.rowStart) + 1);
            federatedRange.setEndDim(1, indexRange.colEnd >= j4 ? j4 - indexRange.colStart : (indexRange.colEnd - indexRange.colStart) + 1);
            strArr[i] = modifyIndices(new long[]{j5, j6, j7, j8}, 3, 7);
            if (this.input1.isFrame()) {
                if (cacheableData.isFederated(FTypes.FType.ROW)) {
                    arrayList = Arrays.asList(((FrameObject) cacheableData).getSchema((int) j7, (int) j8));
                } else {
                    Collections.addAll(arrayList, ((FrameObject) cacheableData).getSchema((int) j7, (int) j8));
                }
            }
            i++;
        }
        long nextFedDataID = FederationUtils.getNextFedDataID();
        FederatedRequest federatedRequest = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, nextFedDataID, cacheableData.getMetaData().getDataCharacteristics(), cacheableData.getDataType());
        FederatedRequest[] callInstruction = FederationUtils.callInstruction(strArr, this.output, nextFedDataID, new CPOperand[]{this.input1}, new long[]{filter.getID()}, InstructionUtils.getExecType(this.instString));
        filter.execute(getTID(), true, federatedRequest);
        filter.execute(getTID(), true, callInstruction, new FederatedRequest[0]);
        if (!this.input1.isFrame()) {
            MatrixObject matrixObject = executionContext.getMatrixObject(this.output);
            matrixObject.getDataCharacteristics().set(filter.getMaxIndexInRange(0), filter.getMaxIndexInRange(1), (int) ((MatrixObject) cacheableData).getBlocksize());
            matrixObject.setFedMapping(filter.copyWithNewID(callInstruction[0].getID()));
        } else {
            FrameObject frameObject = executionContext.getFrameObject(this.output);
            frameObject.setSchema((Types.ValueType[]) arrayList.toArray(new Types.ValueType[0]));
            frameObject.getDataCharacteristics().setDimension(filter.getMaxIndexInRange(0), filter.getMaxIndexInRange(1));
            frameObject.setFedMapping(filter.copyWithNewID(callInstruction[0].getID()));
        }
    }

    /* JADX WARN: Type inference failed for: r0v48, types: [java.lang.Object[], int[]] */
    private void leftIndexing(ExecutionContext executionContext) {
        char c;
        CacheableData<?> cacheableData = executionContext.getCacheableData(this.input1);
        CacheableData<?> cacheableData2 = null;
        ScalarObject scalarObject = null;
        IndexRange indexRange = getIndexRange(executionContext);
        if (indexRange.rowStart < 0 || indexRange.rowStart >= cacheableData.getNumRows() || indexRange.rowEnd >= cacheableData.getNumRows() || indexRange.colStart < 0 || indexRange.colStart >= cacheableData.getNumColumns() || indexRange.colEnd >= cacheableData.getNumColumns()) {
            long j = indexRange.rowStart + 1;
            long j2 = indexRange.rowEnd + 1;
            long j3 = indexRange.colStart + 1;
            long j4 = indexRange.colEnd + 1;
            cacheableData.getNumRows();
            cacheableData.getNumColumns();
            DMLRuntimeException dMLRuntimeException = new DMLRuntimeException("Invalid values for matrix indexing: [" + j + ":" + dMLRuntimeException + "," + j2 + ":" + dMLRuntimeException + "] must be within matrix dimensions [" + j3 + "," + dMLRuntimeException + "].");
            throw dMLRuntimeException;
        }
        if (this.input2.getDataType() != Types.DataType.SCALAR) {
            cacheableData2 = executionContext.getCacheableData(this.input2);
            if ((indexRange.rowEnd - indexRange.rowStart) + 1 != cacheableData2.getNumRows() || (indexRange.colEnd - indexRange.colStart) + 1 != cacheableData2.getNumColumns()) {
                long numRows = cacheableData2.getNumRows();
                long numColumns = cacheableData2.getNumColumns();
                long j5 = indexRange.rowStart + 1;
                long j6 = indexRange.rowEnd + 1;
                long j7 = indexRange.colStart + 1;
                long j8 = indexRange.colEnd + 1;
                DMLRuntimeException dMLRuntimeException2 = new DMLRuntimeException("Invalid values for matrix indexing: dimensions of the source matrix [" + numRows + "x" + dMLRuntimeException2 + "] do not match the shape of the matrix specified by indices [" + numColumns + ":" + dMLRuntimeException2 + ", " + j5 + ":" + dMLRuntimeException2 + "].");
                throw dMLRuntimeException2;
            }
        } else {
            if (!indexRange.isScalar()) {
                throw new DMLRuntimeException("Invalid index range for leftindexing with scalar: " + indexRange.toString() + ".");
            }
            scalarObject = executionContext.getScalarInput(this.input2);
        }
        FederationMap fedMapping = cacheableData.getFedMapping();
        String[] strArr = new String[fedMapping.getSize()];
        ?? r0 = new int[fedMapping.getSize()];
        FederatedRange[] federatedRangeArr = new FederatedRange[fedMapping.getSize()];
        int size = fedMapping.getSize();
        String createCopyInstString = createCopyInstString();
        int i = 0;
        int i2 = 0;
        int size2 = fedMapping.getSize();
        Iterator<Pair<FederatedRange, FederatedData>> it = fedMapping.getMap().iterator();
        while (it.hasNext()) {
            FederatedRange federatedRange = (FederatedRange) it.next().getKey();
            long j9 = federatedRange.getBeginDims()[0];
            long j10 = federatedRange.getEndDims()[0];
            long j11 = federatedRange.getBeginDims()[1];
            long j12 = federatedRange.getEndDims()[1];
            long j13 = indexRange.rowStart >= j9 ? indexRange.rowStart - j9 : 0L;
            long j14 = (indexRange.rowEnd < j9 || indexRange.rowEnd >= j10) ? (j10 - j9) - 1 : indexRange.rowEnd - j9;
            long j15 = indexRange.colStart >= j11 ? indexRange.colStart - j11 : 0L;
            long j16 = (indexRange.colEnd < j11 || indexRange.colEnd >= j12) ? (j12 - j11) - 1 : indexRange.colEnd - j11;
            long[] jArr = {(int) j13, (int) j14, (int) j15, (int) j16};
            if (cacheableData2 != null) {
                if (cacheableData.isFederated(FTypes.FType.ROW)) {
                    long j17 = (i2 + j14) - j13;
                    if (j17 >= 0 && j17 < cacheableData2.getNumRows() && indexRange.rowStart <= j10) {
                        int[] iArr = new int[4];
                        iArr[0] = i2;
                        iArr[1] = (int) j17;
                        iArr[2] = 0;
                        iArr[3] = ((int) cacheableData2.getNumColumns()) - 1;
                        r0[i] = iArr;
                        i2 = (int) (j17 + 1);
                        c = '\b';
                        strArr[i] = modifyIndices(jArr, 4, 8);
                        federatedRangeArr[i] = federatedRange;
                        size2 = Math.min(i, size2);
                    }
                }
                if (cacheableData.isFederated(FTypes.FType.COL)) {
                    char c2 = c;
                    if ((i2 + j16) - j15 >= 0 && c2 < cacheableData2.getNumColumns() && indexRange.colStart <= j12) {
                        int[] iArr2 = new int[4];
                        iArr2[0] = 0;
                        iArr2[1] = ((int) cacheableData2.getNumRows()) - 1;
                        iArr2[2] = i2;
                        iArr2[3] = c2;
                        r0[i] = iArr2;
                        i2 = (int) (c2 + 1);
                        c = '\b';
                        strArr[i] = modifyIndices(jArr, 4, 8);
                        federatedRangeArr[i] = federatedRange;
                        size2 = Math.min(i, size2);
                    }
                }
                size = Math.min(i, size);
                strArr[i] = createCopyInstString;
            } else if (indexRange.rowStart < j9 || indexRange.rowEnd >= j10 || indexRange.colStart < j11 || indexRange.colEnd >= j12) {
                size = Math.min(i, size);
                strArr[i] = createCopyInstString;
            } else {
                c = '\b';
                strArr[i] = modifyIndices(jArr, 4, 8);
                strArr[i] = changeScalarLiteralFlag(strArr[i], 3);
                federatedRangeArr[i] = federatedRange;
                size2 = Math.min(i, size2);
            }
            i++;
        }
        int[][] iArr3 = (int[][]) Arrays.stream((Object[]) r0).filter((v0) -> {
            return Objects.nonNull(v0);
        }).toArray(i3 -> {
            return new int[i3];
        });
        long nextFedDataID = FederationUtils.getNextFedDataID();
        fedMapping.execute(getTID(), true, new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, nextFedDataID, new MatrixCharacteristics(-1L, -1L), cacheableData.getDataType()));
        if (cacheableData2 != null) {
            FederatedRequest[] broadcastSliced = fedMapping.broadcastSliced(cacheableData2, DMLScript.LINEAGE ? executionContext.getLineageItem(this.input2) : null, this.input2.isFrame(), iArr3);
            FederatedRequest[] callInstruction = FederationUtils.callInstruction(strArr, this.output, nextFedDataID, new CPOperand[]{this.input1, this.input2}, new long[]{fedMapping.getID(), broadcastSliced[0].getID()}, null);
            FederatedRequest cleanup = fedMapping.cleanup(getTID(), broadcastSliced[0].getID());
            if (iArr3.length == fedMapping.getSize()) {
                fedMapping.execute(getTID(), true, callInstruction, broadcastSliced, cleanup);
            } else {
                fedMapping.execute(getTID(), true, federatedRangeArr, callInstruction[size], (FederatedRequest[]) Arrays.copyOfRange(callInstruction, size2, size2 + iArr3.length), broadcastSliced, cleanup);
            }
        } else {
            FederatedRequest broadcast = fedMapping.broadcast(scalarObject);
            FederatedRequest[] callInstruction2 = FederationUtils.callInstruction(strArr, this.output, nextFedDataID, new CPOperand[]{this.input1, this.input2}, new long[]{fedMapping.getID(), broadcast.getID()}, null);
            FederatedRequest cleanup2 = fedMapping.cleanup(getTID(), broadcast.getID());
            if (callInstruction2.length == 1) {
                fedMapping.execute(getTID(), true, callInstruction2, broadcast, cleanup2);
            } else {
                fedMapping.execute(getTID(), true, federatedRangeArr, callInstruction2[size], callInstruction2[size2], broadcast, cleanup2);
            }
        }
        if (!this.input1.isFrame()) {
            MatrixObject matrixObject = executionContext.getMatrixObject(this.output);
            matrixObject.getDataCharacteristics().set(cacheableData.getDataCharacteristics());
            matrixObject.setFedMapping(fedMapping.copyWithNewID(nextFedDataID));
        } else {
            FrameObject frameObject = executionContext.getFrameObject(this.output);
            frameObject.setSchema(((FrameObject) cacheableData).getSchema());
            frameObject.getDataCharacteristics().set(cacheableData.getDataCharacteristics());
            frameObject.setFedMapping(fedMapping.copyWithNewID(nextFedDataID));
        }
    }

    private String modifyIndices(long[] jArr, int i, int i2) {
        String[] split = this.instString.split("°");
        for (int i3 = i; i3 < i2; i3++) {
            split[i3] = InstructionUtils.createLiteralOperand(String.valueOf(jArr[i3 - i] + 1), Types.ValueType.INT64);
        }
        return String.join("°", split);
    }

    private String changeScalarLiteralFlag(String str, int i) {
        String[] split = str.split("°");
        split[i] = split[i].replace("true", "false");
        return String.join("°", split);
    }

    private String createCopyInstString() {
        String[] split = this.instString.split("°");
        return VariableCPInstruction.prepareCopyInstruction(split[2], split[8]).toString();
    }
}
