package org.apache.sysds.runtime.instructions.fed;

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.class */
/**
 * Federated variant of the {@code rshape} instruction.
 *
 * <p>The instruction is forwarded to every federated worker with the target dimension that is
 * split across workers rewritten to the worker-local extent (rows when reshaping by row, columns
 * otherwise). Afterwards the federated ranges are adjusted to the new shape and attached to the
 * output matrix object.
 */
public class ReshapeFEDInstruction extends UnaryFEDInstruction {
    /** Separator between the operands of a serialized instruction string. */
    private static final String OPERAND_DELIM = "°";
    /** Index of the rows operand after splitting the full instruction string by {@link #OPERAND_DELIM}. */
    private static final int ROWS_OPERAND_POS = 3;
    /** Index of the cols operand after splitting the full instruction string by {@link #OPERAND_DELIM}. */
    private static final int COLS_OPERAND_POS = 4;

    private final CPOperand _opRows;
    private final CPOperand _opCols;
    private final CPOperand _opDims;
    private final CPOperand _opByRow;

    private ReshapeFEDInstruction(Operator op, CPOperand in, CPOperand rows, CPOperand cols, CPOperand dims, CPOperand byRow, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.Reshape, op, in, out, opcode, istr);
        this._opRows = rows;
        this._opCols = cols;
        this._opDims = dims;
        this._opByRow = byRow;
    }

    /**
     * Parses a serialized {@code rshape} instruction.
     *
     * @param str the instruction string; must consist of the opcode followed by six operands
     *            (input, rows, cols, dims, byRow, output)
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code rshape}
     */
    public static ReshapeFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 6);
        String opcode = parts[0];
        CPOperand in = new CPOperand(parts[1]);
        CPOperand rows = new CPOperand(parts[2]);
        CPOperand cols = new CPOperand(parts[3]);
        CPOperand dims = new CPOperand(parts[4]);
        CPOperand byRow = new CPOperand(parts[5]);
        CPOperand out = new CPOperand(parts[6]);
        if (!opcode.equalsIgnoreCase("rshape")) {
            throw new DMLRuntimeException("Unknown opcode while parsing an ReshapeInstruction: " + str);
        }
        return new ReshapeFEDInstruction(new Operator(true), in, rows, cols, dims, byRow, out, opcode, str);
    }

    /**
     * Validates the requested shape, runs the reshape on all federated workers and derives the
     * federated mapping of the output.
     *
     * @param executionContext context providing the input matrix, the scalar operands and the output
     * @throws DMLRuntimeException if the output is not a matrix, the input is not federated, the
     *                             total number of cells changes, or a worker's partition cannot be
     *                             split into complete rows (by-row) / columns (by-column)
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        if (this.output.getDataType() != Types.DataType.MATRIX) {
            throw new DMLRuntimeException("Federated Reshape Instruction only supports matrix as output.");
        }
        MatrixObject in = executionContext.getMatrixObject(this.input1);
        final boolean byRow = ((BooleanObject) executionContext.getScalarInput(this._opByRow.getName(), Types.ValueType.BOOLEAN, this._opByRow.isLiteral())).getBooleanValue();
        // Keep the target dimensions as long: narrowing them to int and multiplying in 32-bit
        // arithmetic can wrap around and defeat the cell-count check below.
        final long rows = executionContext.getScalarInput(this._opRows).getLongValue();
        final long cols = executionContext.getScalarInput(this._opCols).getLongValue();

        if (!in.isFederated()) {
            throw new DMLRuntimeException("Federated Rshape: Federated input expected, but invoked w/ " + in.isFederated());
        }
        if (in.getNumColumns() * in.getNumRows() != rows * cols) {
            throw new DMLRuntimeException("Reshape matrix requires consistent numbers of input/output cells (" + in.getNumRows() + ":" + in.getNumColumns() + ", " + rows + ":" + cols + ").");
        }

        FederationMap fedMap = in.getFedMapping();
        // Every worker partition must consist of whole output rows (by-row) or columns (by-column).
        final long lineLength = byRow ? cols : rows;
        boolean misaligned = Arrays.stream(fedMap.getFederatedRanges())
            .anyMatch(range -> range.getSize() % lineLength != 0);
        if (misaligned) {
            throw new DMLRuntimeException("Reshape matrix requires consistent numbers of input/output cells for each worker.");
        }

        // Execute the worker-specific reshape instructions at the federated sites.
        String[] workerInstructions = getNewInstString(in, this.instString, rows, cols, byRow);
        FederatedRequest[] requests = FederationUtils.callInstruction(workerInstructions, this.output, new CPOperand[]{this.input1}, new long[]{fedMap.getID()});
        fedMap.execute(getTID(), true, requests, new FederatedRequest[0]);

        // Adjust the federated ranges to the reshaped partitions.
        // NOTE(review): the ranges modified here are obtained from the input's federation map and
        // the output only receives a copyWithNewID(...) afterwards; confirm that altering the
        // input's ranges is intended.
        for (int i = 0; i < fedMap.getFederatedRanges().length; i++) {
            long cells = fedMap.getFederatedRanges()[i].getSize();
            long partRows = byRow ? cells / cols : rows;
            long partCols = byRow ? cols : cells / rows;

            // A range that already started at 0 keeps starting at 0; otherwise it is appended to
            // the end of its predecessor in that dimension.
            long rowBegin = (i == 0 || fedMap.getFederatedRanges()[i].getBeginDims()[0] == 0) ? 0L : fedMap.getFederatedRanges()[i - 1].getEndDims()[0];
            fedMap.getFederatedRanges()[i].setBeginDim(0, rowBegin);
            fedMap.getFederatedRanges()[i].setEndDim(0, fedMap.getFederatedRanges()[i].getBeginDims()[0] + partRows);

            long colBegin = (i == 0 || fedMap.getFederatedRanges()[i].getBeginDims()[1] == 0) ? 0L : fedMap.getFederatedRanges()[i - 1].getEndDims()[1];
            fedMap.getFederatedRanges()[i].setBeginDim(1, colBegin);
            fedMap.getFederatedRanges()[i].setEndDim(1, fedMap.getFederatedRanges()[i].getBeginDims()[1] + partCols);
        }

        // Derive the output characteristics and federated mapping.
        MatrixObject out = executionContext.getMatrixObject(this.output);
        out.getDataCharacteristics().set(rows, cols, (int) in.getBlocksize(), in.getNnz());
        out.setFedMapping(fedMap.copyWithNewID(requests[0].getID()));
    }

    /**
     * Builds one instruction string per federated worker in which the global size of the split
     * dimension is replaced by the worker-local size.
     *
     * @param matrixObject federated input whose ranges determine the per-worker sizes
     * @param instString   the original (global) instruction string
     * @param rows         global number of output rows
     * @param cols         global number of output columns
     * @param byRow        {@code true} to rewrite the rows operand, {@code false} for the cols operand
     * @return one instruction string per entry of the federation map
     */
    private static String[] getNewInstString(MatrixObject matrixObject, String instString, long rows, long cols, boolean byRow) {
        FederationMap fedMap = matrixObject.getFedMapping();
        String[] perWorker = new String[fedMap.getSize()];

        // If all partitions hold the same number of cells, a single instruction serves all workers.
        boolean uniform = Arrays.stream(fedMap.getFederatedRanges())
            .mapToLong(range -> range.getSize())
            .distinct()
            .count() == 1;
        int toBuild = uniform ? 1 : fedMap.getSize();

        int operandPos = byRow ? ROWS_OPERAND_POS : COLS_OPERAND_POS;
        long globalDim = byRow ? rows : cols;
        long fixedDim = byRow ? cols : rows;
        for (int i = 0; i < toBuild; i++) {
            long cells = fedMap.getFederatedRanges()[i].getSize();
            perWorker[i] = replaceDimOperand(instString, operandPos, globalDim, cells / fixedDim);
        }
        if (toBuild == 1) {
            Arrays.fill(perWorker, perWorker[0]);
        }
        return perWorker;
    }

    /**
     * Rewrites the dimension value of exactly one operand of an instruction string.
     *
     * <p>Only the operand at {@code operandPos} is touched, and within it only the leading value
     * token. A plain substring replacement would also rewrite digits in the operand's remaining
     * fields and any other operand with identical text (e.g. rows == cols).
     *
     * @param instString the full instruction string
     * @param operandPos index of the operand after splitting by {@link #OPERAND_DELIM}
     * @param oldValue   the dimension value currently encoded in the operand
     * @param newValue   the dimension value to encode instead
     * @return the instruction string with the single operand rewritten
     */
    private static String replaceDimOperand(String instString, int operandPos, long oldValue, long newValue) {
        // limit -1 keeps trailing empty fields so that join() reproduces all untouched parts exactly
        String[] operands = instString.split(OPERAND_DELIM, -1);
        String operand = operands[operandPos];
        String oldText = String.valueOf(oldValue);
        String newText = String.valueOf(newValue);

        boolean startsWithValue = operand.startsWith(oldText)
            && (operand.length() == oldText.length() || !Character.isDigit(operand.charAt(oldText.length())));
        if (startsWithValue) {
            operands[operandPos] = newText + operand.substring(oldText.length());
        } else {
            // Unrecognized operand layout: fall back to the legacy substring replacement,
            // but still confined to this one operand.
            operands[operandPos] = operand.replace(oldText, newText);
        }
        return String.join(OPERAND_DELIM, operands);
    }

    /**
     * Creates the lineage item of the output from the opcode and the lineage of the input, rows,
     * cols, dims and byRow operands.
     *
     * @param executionContext context used to look up the operands' lineage
     * @return pair of the output variable name and its lineage item
     */
    @Override // org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction, org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        LineageItem[] inputs = LineageItemUtils.getLineage(executionContext, this.input1, this._opRows, this._opCols, this._opDims, this._opByRow);
        return Pair.of(this.output.getName(), new LineageItem(getOpcode(), inputs));
    }
}
