package org.apache.sysds.runtime.controlprogram.federated;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import org.antlr.v4.runtime.tree.xpath.XPath;
import org.apache.log4j.Logger;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;

/**
 * Static helpers for federated execution: allocating federated data IDs,
 * rewriting instructions into {@link FederatedRequest}s, and combining the
 * partial results returned by federated workers into a single result.
 *
 * <p>All aggregation helpers block on the given futures and wrap any failure
 * (including worker-side errors surfaced by {@code Future.get()}) in a
 * {@link DMLRuntimeException}.
 */
public class FederationUtils {
    protected static Logger log = Logger.getLogger(FederationUtils.class);
    private static final IDSequence _idSeq = new IDSequence();

    /** Delimiter placed before an operand name in an instruction string (U+00B0, degree sign). */
    private static final String OPERAND_DELIM = "\u00b0";
    /** Separator placed right after an operand name in an instruction string (U+00B7, middle dot). */
    private static final String TYPE_PREFIX = "\u00b7";

    /** Resets the federated data ID sequence. */
    public static void resetFedDataID() {
        _idSeq.reset();
    }

    /**
     * Returns the next federated data ID from the shared sequence.
     *
     * @return a new federated data ID
     */
    public static long getNextFedDataID() {
        return _idSeq.getNextID();
    }

    /**
     * Rewrites an instruction so that it can be executed on a federated worker and wraps it in an
     * {@code EXEC_INST} request.
     *
     * <p>The rewrite does three things. It replaces every occurrence of the SPARK exec type with CP.
     * It renames the output operand to a freshly allocated federated ID. It renames each non-null
     * input operand {@code varOldIn[i]} to {@code varNewIn[i]}.
     *
     * @param inst      the instruction string to rewrite
     * @param varOldOut the output operand whose name is replaced by the new ID
     * @param varOldIn  input operands to rename; {@code null} entries are skipped
     * @param varNewIn  replacement IDs, index-aligned with {@code varOldIn}
     * @return a request carrying the new output ID and the rewritten instruction
     */
    public static FederatedRequest callInstruction(String inst, CPOperand varOldOut, CPOperand[] varOldIn, long[] varNewIn) {
        long id = getNextFedDataID();
        String linst = inst.replace(Types.ExecType.SPARK.name(), Types.ExecType.CP.name());
        linst = linst.replace(OPERAND_DELIM + varOldOut.getName() + TYPE_PREFIX, OPERAND_DELIM + id + TYPE_PREFIX);
        for (int i = 0; i < varOldIn.length; i++) {
            if (varOldIn[i] == null) {
                continue;
            }
            String oldName = varOldIn[i].getName();
            String newName = String.valueOf(varNewIn[i]);
            linst = linst.replace(OPERAND_DELIM + oldName + TYPE_PREFIX, OPERAND_DELIM + newName + TYPE_PREFIX);
            // The "=name" form appears to cover parameterized (key=value) operands.
            // NOTE(review): this replacement is not delimited on the right, so a name that is a prefix
            // of another operand name (e.g. "X" vs "X2") would also be rewritten.
            linst = linst.replace("=" + oldName, "=" + newName);
        }
        return new FederatedRequest(FederatedRequest.RequestType.EXEC_INST, id, linst);
    }

    /**
     * Element-wise sums the matrix results of all responses.
     *
     * @param futures pending responses whose first data item is a {@link MatrixBlock}
     * @return the element-wise sum
     * @throws DMLRuntimeException if any response fails or the aggregation fails
     */
    public static MatrixBlock aggAdd(Future<FederatedResponse>[] futures) {
        try {
            SimpleOperator plus = new SimpleOperator(Plus.getPlusFnObject());
            return MatrixBlock.naryOperations(plus, getResults(futures), new ScalarObject[0], new MatrixBlock());
        } catch (Exception e) {
            throw wrap(e);
        }
    }

    /**
     * Combines partial means into a global mean.
     *
     * <p>Each partial mean is weighted by {@code getSize(0)} of its federated range (presumably the
     * row count). The weighted values are summed and then divided by the total of those sizes.
     * Assumes the i-th future corresponds to the i-th entry of
     * {@code federationMap.getFederatedRanges()} — TODO confirm ordering guarantee.
     *
     * @param futures       pending responses whose first data item is a partial-mean {@link MatrixBlock}
     * @param federationMap the federation map providing the range sizes used as weights
     * @return the combined mean
     * @throws DMLRuntimeException if there are no responses, any response fails, or the aggregation fails
     */
    public static MatrixBlock aggMean(Future<FederatedResponse>[] futures, FederationMap federationMap) {
        if (futures.length == 0) {
            throw new DMLRuntimeException("Cannot aggregate the mean of an empty set of federated responses.");
        }
        try {
            FederatedRange[] ranges = federationMap.getFederatedRanges();
            BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
            ScalarOperator multiply = InstructionUtils.parseScalarBinaryOperator("*", false);
            MatrixBlock weightedSum = null;
            long totalRows = 0;
            for (int i = 0; i < futures.length; i++) {
                MatrixBlock partialMean = (MatrixBlock) futures[i].get().getData()[0];
                long rows = ranges[i].getSize(0);
                totalRows += rows;
                multiply = multiply.setConstant(rows);
                MatrixBlock weighted = partialMean.scalarOperations(multiply, (MatrixValue) new MatrixBlock());
                weightedSum = (weightedSum == null)
                    ? weighted
                    : weightedSum.binaryOperationsInPlace(plus, (MatrixValue) weighted);
            }
            ScalarOperator divide = InstructionUtils.parseScalarBinaryOperator("/", false).setConstant(totalRows);
            return weightedSum.scalarOperations(divide, (MatrixValue) new MatrixBlock());
        } catch (Exception e) {
            throw wrap(e);
        }
    }

    /**
     * Computes the global minimum or maximum across all responses.
     *
     * @param futures  pending responses
     * @param isMin    {@code true} to compute the minimum, {@code false} for the maximum
     * @param isScalar {@code true} if each response's first data item is a {@link ScalarObject};
     *                 otherwise it is a {@link MatrixBlock} whose min/max is taken first
     * @return the aggregated value; {@code Double.MAX_VALUE} (min) or {@code -Double.MAX_VALUE} (max)
     *         when {@code futures} is empty
     * @throws DMLRuntimeException if any response fails
     */
    public static DoubleObject aggMinMax(Future<FederatedResponse>[] futures, boolean isMin, boolean isScalar) {
        try {
            double result = isMin ? Double.MAX_VALUE : -Double.MAX_VALUE;
            for (Future<FederatedResponse> future : futures) {
                Object data = future.get().getData()[0];
                double partial;
                if (isScalar) {
                    partial = ((ScalarObject) data).getDoubleValue();
                } else if (isMin) {
                    partial = ((MatrixBlock) data).min();
                } else {
                    partial = ((MatrixBlock) data).max();
                }
                result = isMin ? Math.min(result, partial) : Math.max(result, partial);
            }
            return new DoubleObject(result);
        } catch (Exception e) {
            throw wrap(e);
        }
    }

    /**
     * Blocks on all responses and collects their first data item as matrix blocks.
     *
     * @param futures pending responses whose first data item is a {@link MatrixBlock}
     * @return the matrix blocks, in the same order as {@code futures}
     * @throws DMLRuntimeException if any response fails
     */
    public static MatrixBlock[] getResults(Future<FederatedResponse>[] futures) {
        try {
            MatrixBlock[] results = new MatrixBlock[futures.length];
            for (int i = 0; i < futures.length; i++) {
                results[i] = (MatrixBlock) futures[i].get().getData()[0];
            }
            return results;
        } catch (Exception e) {
            throw wrap(e);
        }
    }

    /**
     * Row-binds (vertically appends) the matrix results of all responses, in order.
     *
     * @param futures pending responses whose first data item is a {@link MatrixBlock}
     * @return the row-bound matrix
     * @throws DMLRuntimeException if there are no responses, any response fails, or the append fails
     */
    public static MatrixBlock rbind(Future<FederatedResponse>[] futures) {
        if (futures.length == 0) {
            throw new DMLRuntimeException("Cannot rbind an empty set of federated responses.");
        }
        try {
            MatrixBlock[] results = getResults(futures);
            // cbind=false selects row-wise append
            return results[0].append(Arrays.copyOfRange(results, 1, results.length), new MatrixBlock(), false);
        } catch (Exception e) {
            throw wrap(e);
        }
    }

    /**
     * Aggregates scalar partial results into a single scalar.
     *
     * <p>Supports {@link KahanFunction}-based aggregates (partials are summed) and MIN/MAX builtins.
     *
     * @param aop     the aggregate operator that produced the partials
     * @param futures pending responses whose first data item is a {@link ScalarObject}
     * @return the aggregated scalar
     * @throws DMLRuntimeException if the operator is unsupported or any response fails
     */
    public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] futures) {
        boolean isKahan = aop.aggOp.increOp.fn instanceof KahanFunction;
        boolean isMinMax = isMinMax(aop);
        if (!isKahan && !isMinMax) {
            throw new DMLRuntimeException("Unsupported aggregation operator: "
                + aop.aggOp.increOp.fn.getClass().getSimpleName());
        }
        if (isMinMax) {
            return aggMinMax(futures, isMin(aop), true);
        }
        try {
            double sum = 0.0d;
            for (Future<FederatedResponse> future : futures) {
                sum += ((ScalarObject) future.get().getData()[0]).getDoubleValue();
            }
            return new DoubleObject(sum);
        } catch (Exception e) {
            throw wrap(e);
        }
    }

    /**
     * Aggregates matrix partial results according to the aggregate operator.
     *
     * <p>Row aggregates are row-bound. Otherwise {@link KahanFunction} aggregates are summed,
     * {@link Mean} aggregates are combined via {@link #aggMean}, and MIN/MAX produce a 1x1 matrix.
     *
     * @param aop           the aggregate operator that produced the partials
     * @param futures       pending responses whose first data item is a {@link MatrixBlock}
     * @param federationMap the federation map (used only for mean aggregation)
     * @return the aggregated matrix
     * @throws DMLRuntimeException if the operator is unsupported or any response fails
     */
    public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] futures, FederationMap federationMap) {
        if (aop.isRowAggregate()) {
            return rbind(futures);
        }
        if (aop.aggOp.increOp.fn instanceof KahanFunction) {
            return aggAdd(futures);
        }
        if (aop.aggOp.increOp.fn instanceof Mean) {
            return aggMean(futures, federationMap);
        }
        if (isMinMax(aop)) {
            return new MatrixBlock(1, 1, aggMinMax(futures, isMin(aop), false).getDoubleValue());
        }
        throw new DMLRuntimeException("Unsupported aggregation operator: "
            + aop.aggOp.increOp.fn.getClass().getSimpleName());
    }

    /**
     * Blocks until every future has completed.
     *
     * @param futures the pending responses
     * @throws DMLRuntimeException if any future fails or the wait is interrupted
     */
    public static void waitFor(List<Future<FederatedResponse>> futures) {
        try {
            for (Future<FederatedResponse> future : futures) {
                future.get();
            }
        } catch (Exception e) {
            throw wrap(e);
        }
    }

    /**
     * Wraps local data as a single-partition federation map covering the whole data.
     *
     * <p>The range spans from (0, 0) to (numRows, numColumns).
     *
     * @param data the local data to federate
     * @return a federation map with one local partition under a newly allocated ID
     */
    public static FederationMap federateLocalData(CacheableData<?> data) {
        long id = getNextFedDataID();
        FederatedLocalData localData = new FederatedLocalData(id, data);
        Map<FederatedRange, FederatedData> fedMapping = new HashMap<>();
        FederatedRange fullRange = new FederatedRange(
            new long[] {0, 0}, new long[] {data.getNumRows(), data.getNumColumns()});
        fedMapping.put(fullRange, localData);
        return new FederationMap(id, fedMapping);
    }

    /** Returns true if the operator's increment function is the MIN or MAX builtin. */
    private static boolean isMinMax(AggregateUnaryOperator aop) {
        if (!(aop.aggOp.increOp.fn instanceof Builtin)) {
            return false;
        }
        Builtin.BuiltinCode code = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode();
        return code == Builtin.BuiltinCode.MIN || code == Builtin.BuiltinCode.MAX;
    }

    /** Returns true if the operator's increment function is MIN; call only after {@link #isMinMax}. */
    private static boolean isMin(AggregateUnaryOperator aop) {
        return ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == Builtin.BuiltinCode.MIN;
    }

    /**
     * Wraps a failure in a {@link DMLRuntimeException}. If the failure is an interruption, the
     * thread's interrupt status is restored first so that callers can still observe it.
     */
    private static DMLRuntimeException wrap(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return new DMLRuntimeException(e);
    }
}
