package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Federated instruction for the central-moment operation, with or without weights.
 * <p>
 * Every federated partition of the input computes a partial {@link CM_COV_Object} through a
 * {@link FederatedUDF}. The coordinator merges the partial objects with the operator's
 * function ({@code CMOperator.fn}) and stores the operator's required result as a scalar output.
 */
public class CentralMomentFEDInstruction extends AggregateUnaryFEDInstruction {

    /**
     * Federated UDF that runs {@code MatrixBlock.cmOperations(op)} on the partition bound to the
     * given federated variable id and returns the resulting {@link CM_COV_Object}.
     */
    private static class CMFunction extends FederatedUDF {
        private static final long serialVersionUID = 7460149207607220994L;
        private final CMOperator _op;

        /**
         * @param input federated variable id of the partition to process
         * @param op    central-moment operator to apply
         */
        public CMFunction(long input, CMOperator op) {
            super(new long[] {input});
            _op = op;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            MatrixBlock local = ((MatrixObject) data[0]).acquireReadAndRelease();
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, local.cmOperations(_op));
        }

        /** Lineage is not provided for this UDF; always returns {@code null}. */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }

    /**
     * Federated UDF that runs {@code MatrixBlock.cmOperations(op, weights)} on the partition bound
     * to the given federated variable id and returns the resulting {@link CM_COV_Object}.
     * The weights must be row-aligned with that partition.
     */
    private static class CMWeightsFunction extends FederatedUDF {
        private static final long serialVersionUID = -3685746246551622021L;
        private final CMOperator _op;
        private final MatrixBlock _weights;

        /**
         * @param input   federated variable id of the partition to process
         * @param op      central-moment operator to apply
         * @param weights weights covering exactly the rows of that partition
         */
        protected CMWeightsFunction(long input, CMOperator op, MatrixBlock weights) {
            super(new long[] {input});
            _op = op;
            _weights = weights;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            MatrixBlock local = ((MatrixObject) data[0]).acquireReadAndRelease();
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
                local.cmOperations(_op, _weights));
        }

        /** Lineage is not provided for this UDF; always returns {@code null}. */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }

    private CentralMomentFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3,
        CPOperand out, String opcode, String instr) {
        super(op, in1, in2, in3, out, opcode, instr);
    }

    /**
     * Parses a FED instruction string. The last instruction part is the federated-output flag.
     * The remainder, with that flag removed, is parsed as a CP central-moment instruction.
     *
     * @param str full FED instruction string
     * @return the federated instruction with its {@code _fedOut} flag set
     */
    public static CentralMomentFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        FEDInstruction.FederatedOutput fedOut = FEDInstruction.FederatedOutput.valueOf(parts[parts.length - 1]);
        CentralMomentCPInstruction cpInst =
            CentralMomentCPInstruction.parseInstruction(InstructionUtils.removeFEDOutputFlag(str));
        CentralMomentFEDInstruction fedInst = parseInstruction(cpInst);
        fedInst._fedOut = fedOut;
        return fedInst;
    }

    /**
     * Converts an existing instruction into its federated counterpart.
     * <ul>
     * <li>A {@link CentralMomentCPInstruction} is converted directly.</li>
     * <li>An {@link SPInstruction} has its instruction string re-parsed as a CP central-moment
     * instruction.</li>
     * <li>Any other instruction has its instruction string parsed as a FED instruction.</li>
     * </ul>
     */
    public static CentralMomentFEDInstruction parseInstruction(Instruction inst) {
        if (inst instanceof CentralMomentCPInstruction) {
            return parseInstruction((CentralMomentCPInstruction) inst);
        }
        if (inst instanceof SPInstruction) {
            return parseInstruction(CentralMomentCPInstruction.parseInstruction(inst.getInstructionString()));
        }
        return parseInstruction(inst.getInstructionString());
    }

    /** Builds the federated instruction from the operands of a CP central-moment instruction. */
    public static CentralMomentFEDInstruction parseInstruction(CentralMomentCPInstruction cpInst) {
        return new CentralMomentFEDInstruction(cpInst.getOperator(), cpInst.input1, cpInst.input2,
            cpInst.input3, cpInst.output, cpInst.getOpcode(), cpInst.getInstructionString());
    }

    /**
     * Computes partial central moments on all federated partitions of {@code input1} and merges them.
     * The merged operator result is written to {@code output} as a double scalar.
     * <p>
     * The order scalar is read from {@code input2} when {@code input3} is null. Otherwise it is
     * read from {@code input3}, and {@code input2} holds the weights.
     *
     * @throws DMLRuntimeException if a federated request fails, no partial result is produced, or
     *                             the final result cannot be computed
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo = ec.getMatrixObject(input1.getName());
        ScalarObject order = ec.getScalarInput(input3 == null ? input2 : input3);

        CMOperator cmOp = (CMOperator) _optr;
        if (cmOp.getAggOpType() == CMOperator.AggregateOperationTypes.INVALID) {
            // The aggregation type was not fixed at compile time; bind it from the runtime order.
            cmOp = cmOp.setCMAggOp((int) order.getLongValue());
        }
        final CMOperator finalOp = cmOp;

        List<CM_COV_Object> partials;
        if (input3 == null) {
            partials = collectPartials(mo, finalOp, null);
        }
        else {
            // Pin the weights once for all partitions and always release the read lock again.
            // Previously they were acquired once per partition and never released.
            MatrixBlock weights = ec.getMatrixInput(input2.getName());
            try {
                partials = collectPartials(mo, finalOp, weights);
            }
            finally {
                ec.releaseMatrixInput(input2.getName());
            }
        }

        CM_COV_Object merged = partials.stream()
            .reduce((left, right) -> (CM_COV_Object) finalOp.fn.execute(left, right))
            .orElseThrow(() -> new DMLRuntimeException(
                "Federated central moment: no partial results received from federated workers."));
        try {
            ec.setScalarOutput(output.getName(), new DoubleObject(merged.getRequiredResult(finalOp)));
        }
        catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Runs the central-moment UDF on every federated partition of {@code mo} in parallel and
     * collects the partial results. Results are collected in completion order.
     *
     * @param mo      federated input matrix
     * @param op      central-moment operator with a resolved aggregation type
     * @param weights full weight vector, or {@code null} for the unweighted variant; each
     *                partition receives only the rows of its own federated range
     * @return the partial {@link CM_COV_Object}s, one per successfully processed partition
     */
    private static List<CM_COV_Object> collectPartials(MatrixObject mo, CMOperator op, MatrixBlock weights) {
        FederationMap fedMapping = mo.getFedMapping();
        List<CM_COV_Object> partials = new ArrayList<>();
        fedMapping.mapParallel(FederationUtils.getNextFedDataID(), (range, data) -> {
            try {
                FederatedUDF udf;
                if (weights == null) {
                    udf = new CMFunction(data.getVarID(), op);
                }
                else {
                    // cmOperations(op, weights) requires row-aligned inputs, so ship only this
                    // partition's rows. The federated range is [begin, end); slice bounds are inclusive.
                    int rowLower = (int) range.getBeginDims()[0];
                    int rowUpper = (int) range.getEndDims()[0] - 1;
                    udf = new CMWeightsFunction(data.getVarID(), op, weights.slice(rowLower, rowUpper));
                }
                FederatedResponse response = data.executeFederatedOperation(
                    new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, udf)).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
                // mapParallel invokes this callback concurrently; guard the shared list.
                synchronized (partials) {
                    partials.add((CM_COV_Object) response.getData()[0]);
                }
                return null;
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        });
        return partials;
    }
}
