package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageTraceable;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.class */
public class VariableFEDInstruction extends FEDInstruction implements LineageTraceable {
    private static final Log LOG = LogFactory.getLog(VariableFEDInstruction.class.getName());
    private final VariableCPInstruction _in;

    protected VariableFEDInstruction(VariableCPInstruction variableCPInstruction) {
        super(null, variableCPInstruction.getOperator(), variableCPInstruction.getOpcode(), variableCPInstruction.getInstructionString());
        this._in = variableCPInstruction;
    }

    public static VariableFEDInstruction parseInstruction(VariableCPInstruction variableCPInstruction, ExecutionContext executionContext) {
        if (variableCPInstruction.getVariableOpcode() == VariableCPInstruction.VariableOperationCode.Write && variableCPInstruction.getInput1().isMatrix() && variableCPInstruction.getInput3().getName().contains("federated")) {
            return parseInstruction(variableCPInstruction);
        }
        if (variableCPInstruction.getVariableOpcode() == VariableCPInstruction.VariableOperationCode.CastAsFrameVariable && variableCPInstruction.getInput1().isMatrix() && executionContext.getCacheableData(variableCPInstruction.getInput1()).isFederatedExcept(FTypes.FType.BROADCAST)) {
            return parseInstruction(variableCPInstruction);
        }
        if (variableCPInstruction.getVariableOpcode() == VariableCPInstruction.VariableOperationCode.CastAsMatrixVariable && variableCPInstruction.getInput1().isFrame() && executionContext.getCacheableData(variableCPInstruction.getInput1()).isFederatedExcept(FTypes.FType.BROADCAST)) {
            return parseInstruction(variableCPInstruction);
        }
        return null;
    }

    private static VariableFEDInstruction parseInstruction(VariableCPInstruction variableCPInstruction) {
        return new VariableFEDInstruction(variableCPInstruction);
    }

    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        VariableCPInstruction.VariableOperationCode variableOpcode = this._in.getVariableOpcode();
        switch (variableOpcode) {
            case Write:
                processWriteInstruction(executionContext);
                return;
            case CastAsMatrixVariable:
                processCastAsMatrixVariableInstruction(executionContext);
                return;
            case CastAsFrameVariable:
                processCastAsFrameVariableInstruction(executionContext);
                return;
            default:
                throw new DMLRuntimeException("Unsupported Opcode for federated Variable Instruction : " + variableOpcode);
        }
    }

    private void processWriteInstruction(ExecutionContext executionContext) {
        LOG.warn("Processing write command federated");
        this._in.processInstruction(executionContext);
    }

    private void processCastAsMatrixVariableInstruction(ExecutionContext executionContext) {
        FrameObject frameObject = executionContext.getFrameObject(this._in.getInput1());
        if (!frameObject.isFederated()) {
            throw new DMLRuntimeException("Federated Reorg: Federated input expected, but invoked w/ " + frameObject.isFederated());
        }
        FederatedRequest callInstruction = FederationUtils.callInstruction(this._in.getInstructionString(), this._in.getOutput(), new CPOperand[]{this._in.getInput1()}, new long[]{frameObject.getFedMapping().getID()});
        frameObject.getFedMapping().execute(getTID(), true, callInstruction);
        MatrixObject matrixObject = executionContext.getMatrixObject(this._in.getOutput());
        FederationMap copyWithNewID = frameObject.getFedMapping().copyWithNewID(callInstruction.getID());
        ArrayList arrayList = new ArrayList();
        for (Pair<FederatedRange, FederatedData> pair : copyWithNewID.getMap()) {
            FederatedData federatedData = (FederatedData) pair.getValue();
            arrayList.add(Pair.of((FederatedRange) pair.getKey(), new FederatedData(Types.DataType.MATRIX, federatedData.getAddress(), federatedData.getFilepath(), federatedData.getVarID())));
        }
        matrixObject.setFedMapping(copyWithNewID);
    }

    private void processCastAsFrameVariableInstruction(ExecutionContext executionContext) {
        MatrixObject matrixObject = executionContext.getMatrixObject(this._in.getInput1());
        if (!matrixObject.isFederated()) {
            throw new DMLRuntimeException("Federated Reorg: Federated input expected, but invoked w/ " + matrixObject.isFederated());
        }
        FederatedRequest callInstruction = FederationUtils.callInstruction(this._in.getInstructionString(), this._in.getOutput(), new CPOperand[]{this._in.getInput1()}, new long[]{matrixObject.getFedMapping().getID()});
        matrixObject.getFedMapping().execute(getTID(), true, callInstruction);
        FrameObject frameObject = executionContext.getFrameObject(this._in.getOutput());
        frameObject.getDataCharacteristics().set(matrixObject.getNumRows(), matrixObject.getNumColumns(), matrixObject.getBlocksize(), matrixObject.getNnz());
        FederationMap copyWithNewID = matrixObject.getFedMapping().copyWithNewID(callInstruction.getID());
        ArrayList arrayList = new ArrayList();
        for (Map.Entry entry : copyWithNewID.getMap()) {
            FederatedData federatedData = (FederatedData) entry.getValue();
            arrayList.add(Pair.of((FederatedRange) entry.getKey(), new FederatedData(Types.DataType.FRAME, federatedData.getAddress(), federatedData.getFilepath(), federatedData.getVarID())));
        }
        Types.ValueType[] valueTypeArr = new Types.ValueType[(int) matrixObject.getDataCharacteristics().getCols()];
        Arrays.fill(valueTypeArr, Types.ValueType.FP64);
        frameObject.setSchema(valueTypeArr);
        frameObject.setFedMapping(copyWithNewID);
    }

    @Override // org.apache.sysds.runtime.lineage.LineageTraceable
    public Pair<String, LineageItem> getLineageItem(ExecutionContext executionContext) {
        return this._in.getLineageItem(executionContext);
    }
}
