package org.apache.sysds.hops.fedplanner;

import java.util.Map;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;

/* loaded from: input_file:org/apache/sysds/hops/fedplanner/AFederatedPlanner.class */
/**
 * Base class for federated planners. It provides shared helpers that decide, per hop,
 * whether the hop may run federated given the federation types of its inputs
 * ({@link #allowsFederated}). It also provides helpers that derive the federation type
 * of the hop's output ({@link #getFederatedOut}, {@link #deriveFType}).
 *
 * <p>A {@code null} {@link FTypes.FType} consistently means "not federated" throughout
 * this class.
 */
public abstract class AFederatedPlanner {

    /**
     * Entry point implemented by concrete planners. Planning of {@code prog} for federated
     * execution is presumably done here; the exact behavior is defined by subclasses.
     *
     * @param prog       the DML program to plan
     * @param fgraph     function call graph (presumably of {@code prog})
     * @param fcallSizes size information for function calls (presumably of {@code prog})
     */
    public abstract void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes);

    /**
     * Collects the federation type of every input of {@code hop}.
     *
     * @param hop      hop whose inputs are looked up
     * @param fedTypes map from hop ID to federation type
     * @return array aligned with {@code hop.getInput()}; an entry is {@code null} when the
     *         corresponding input has no mapping
     */
    private static FTypes.FType[] getInputFTypes(Hop hop, Map<Long, FTypes.FType> fedTypes) {
        final int numInputs = hop.getInput().size();
        final FTypes.FType[] ft = new FTypes.FType[numInputs];
        for (int i = 0; i < numInputs; i++) {
            // explicit Long boxing so the lookup key matches the map's key type
            ft[i] = fedTypes.get(Long.valueOf(hop.getInput(i).getHopID()));
        }
        return ft;
    }

    /**
     * Decides whether {@code hop} may be executed federated, looking up the federation
     * types of its inputs in {@code fedTypes}.
     *
     * <p>NOTE(review): the decompiler reports this method was {@code protected} in the
     * original source; it is kept {@code public} so existing callers keep compiling.
     *
     * @param hop      hop to check
     * @param fedTypes map from hop ID to federation type of already-planned hops
     * @return {@code true} if the hop may be executed federated
     */
    public boolean allowsFederated(Hop hop, Map<Long, FTypes.FType> fedTypes) {
        return allowsFederated(hop, getInputFTypes(hop, fedTypes));
    }

    /**
     * Decides whether {@code hop} may be executed federated given the federation types of
     * its inputs.
     *
     * @param hop hop to check
     * @param ft  federation type per input of {@code hop} ({@code null} = not federated)
     * @return {@code true} if the hop may be executed federated
     */
    public boolean allowsFederated(Hop hop, FTypes.FType[] ft) {
        if (hop instanceof AggBinaryOp) {
            // exactly one federated input, or a COL-partitioned left times ROW-partitioned right
            return (ft[0] != null && ft[1] == null)
                || (ft[0] == null && ft[1] != null)
                || (ft[0] == FTypes.FType.COL && ft[1] == FTypes.FType.ROW);
        }
        if (hop instanceof BinaryOp && !hop.getDataType().isScalar()) {
            // exactly one federated input, or both federated with the same type
            return (ft[0] != null && ft[1] == null)
                || (ft[0] == null && ft[1] != null)
                || (ft[0] != null && ft[0] == ft[1]);
        }
        if (hop instanceof TernaryOp && !hop.getDataType().isScalar()) {
            // any federated input suffices
            return ft[0] != null || ft[1] != null || ft[2] != null;
        }
        if (HopRewriteUtils.isReorg(hop, Types.ReOrgOp.TRANS)) {
            return ft[0] == FTypes.FType.COL || ft[0] == FTypes.FType.ROW;
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED)
            || HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE)
            || HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)) {
            return true;
        }
        // Remaining single-input hops: only SUM/MIN/MAX aggregates over a federated input.
        // (Transpose already returned above, so it is not re-checked here.)
        return ft.length == 1 && ft[0] != null
            && HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.SUM, Types.AggOp.MIN, Types.AggOp.MAX);
    }

    /**
     * Derives the federation type of the output of {@code hop}, looking up the federation
     * types of its inputs in {@code fedTypes}.
     *
     * <p>NOTE(review): the decompiler reports this method was {@code protected} in the
     * original source; it is kept {@code public} so existing callers keep compiling.
     *
     * @param hop      hop whose output type is derived
     * @param fedTypes map from hop ID to federation type of already-planned hops
     * @return the output federation type, or {@code null} if the output is not federated
     */
    public FTypes.FType getFederatedOut(Hop hop, Map<Long, FTypes.FType> fedTypes) {
        return getFederatedOut(hop, getInputFTypes(hop, fedTypes));
    }

    /**
     * Derives the federation type of the output of {@code hop} from the federation types
     * of its inputs.
     *
     * @param hop hop whose output type is derived
     * @param ft  federation type per input of {@code hop} ({@code null} = not federated)
     * @return the output federation type, or {@code null} if the output is not federated
     *         (always {@code null} for scalar outputs)
     */
    public FTypes.FType getFederatedOut(Hop hop, FTypes.FType[] ft) {
        if (hop.isScalar()) {
            return null;
        }
        if (hop instanceof AggBinaryOp) {
            MMTSJ.MMTSJType tsmm = ((AggBinaryOp) hop).checkTransposeSelf();
            // transpose-self multiply over the matching partitioning yields a broadcast result
            if (tsmm != MMTSJ.MMTSJType.NONE
                && ((tsmm.isLeft() && ft[0] == FTypes.FType.ROW)
                    || (tsmm.isRight() && ft[0] == FTypes.FType.COL))) {
                return FTypes.FType.BROADCAST;
            }
            return ft[0] == FTypes.FType.ROW ? FTypes.FType.ROW : null;
        }
        if (hop instanceof BinaryOp) {
            // first federated input wins
            return ft[0] != null ? ft[0] : ft[1];
        }
        if (hop instanceof TernaryOp) {
            // first federated input wins
            return ft[0] != null ? ft[0] : (ft[1] != null ? ft[1] : ft[2]);
        }
        if (HopRewriteUtils.isReorg(hop, Types.ReOrgOp.TRANS)) {
            // transposing swaps row and column partitioning
            if (ft[0] == FTypes.FType.ROW) {
                return FTypes.FType.COL;
            }
            if (ft[0] == FTypes.FType.COL) {
                return FTypes.FType.ROW;
            }
            return null;
        }
        if (hop instanceof AggUnaryOp) {
            boolean isColAgg = ((AggUnaryOp) hop).getDirection().isCol();
            // the partitioning survives only if the aggregation does not run across it
            if (ft[0] == FTypes.FType.ROW) {
                return isColAgg ? null : FTypes.FType.ROW;
            }
            if (ft[0] == FTypes.FType.COL) {
                return isColAgg ? FTypes.FType.COL : null;
            }
            return null;
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED)) {
            return deriveFType((DataOp) hop);
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE)
            || HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)) {
            // A transient read may have no inputs; treat it as not federated here instead of
            // failing with ArrayIndexOutOfBoundsException.
            return ft.length > 0 ? ft[0] : null;
        }
        return null;
    }

    /**
     * Derives the federation type of a federated data hop from its {@code FED_RANGES}
     * parameter.
     *
     * <p>The ranges hop's inputs are read as consecutive (begin, end) pairs, each holding a
     * row coordinate (input 0) and a column coordinate (input 1). Presumably there is one
     * pair per federated worker; TODO confirm.
     *
     * @param fedInst federated data hop
     * @return {@code FULL} if every range spans both all rows and all columns (also when
     *         there are no ranges); {@code ROW} if every range spans all columns;
     *         {@code COL} if every range spans all rows; otherwise {@code OTHER}
     */
    public FTypes.FType deriveFType(DataOp fedInst) {
        Hop ranges = fedInst.getInput(fedInst.getParameterIndex(DataExpression.FED_RANGES));
        boolean rowPartitioned = true;
        boolean colPartitioned = true;
        for (int i = 0; i < ranges.getInput().size() / 2; i++) {
            Hop beg = ranges.getInput(2 * i);
            Hop end = ranges.getInput(2 * i + 1);
            long rowBeg = HopRewriteUtils.getIntValueSafe(beg.getInput(0));
            long rowEnd = HopRewriteUtils.getIntValueSafe(end.getInput(0));
            long colBeg = HopRewriteUtils.getIntValueSafe(beg.getInput(1));
            long colEnd = HopRewriteUtils.getIntValueSafe(end.getInput(1));
            // a range covering the full column width is compatible with row partitioning
            rowPartitioned &= (colEnd - colBeg == fedInst.getDim2());
            // a range covering the full row height is compatible with column partitioning
            colPartitioned &= (rowEnd - rowBeg == fedInst.getDim1());
        }
        if (rowPartitioned && colPartitioned) {
            return FTypes.FType.FULL;
        }
        if (rowPartitioned) {
            return FTypes.FType.ROW;
        }
        return colPartitioned ? FTypes.FType.COL : FTypes.FType.OTHER;
    }
}
