package org.apache.sysds.hops.ipa;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.util.ProgramConverter;

/* loaded from: input_file:org/apache/sysds/hops/ipa/FunctionCallSizeInfo.class */
/**
 * Summary of function call sites used by inter-procedural analysis to decide which size
 * information may be propagated into function bodies.
 *
 * <p>For the functions reported by {@link FunctionCallGraph#getReachableFunctions()}, this class
 * records:
 * <ul>
 *   <li>valid propagation candidates: functions with exactly one call site, or whose call sites
 *       all pass inputs with equal known dimensions, equal nnz and (for literal inputs) equal
 *       literal values;</li>
 *   <li>for each valid candidate, the input positions whose nnz is known ({@code >= 0}) at the
 *       first call site;</li>
 *   <li>for each called function, the input positions that receive the same literal value at
 *       every call site;</li>
 *   <li>a set of dims-preserving functions, populated externally through
 *       {@link #addDimsPreservingFunction(String)}.</li>
 * </ul>
 */
public class FunctionCallSizeInfo {
    /** Call graph this summary was built from. */
    private final FunctionCallGraph _fgraph;
    /** Function keys that are valid candidates for size propagation. */
    private final Set<String> _fcand;
    /** Function keys registered via {@link #addDimsPreservingFunction(String)}. */
    private final Set<String> _fcandUnary;
    /** Per valid candidate: input positions whose nnz is known at the first call site. */
    private final Map<String, Set<Integer>> _fcandSafeNNZ;
    /** Per called function: input positions receiving equal literals at every call site. */
    private final Map<String, Set<Integer>> _fSafeLiterals;

    /**
     * Builds the call-size summary for the functions reachable in the given call graph.
     *
     * @param functionCallGraph function call graph to summarize
     */
    public FunctionCallSizeInfo(FunctionCallGraph functionCallGraph) {
        this(functionCallGraph, true);
    }

    /**
     * Builds the call-size summary for the functions reachable in the given call graph.
     *
     * <p>NOTE(review): {@code init} is currently ignored and the summary is always constructed.
     * It presumably was meant to skip construction when {@code false}; confirm against callers
     * before honoring it, since doing so would change the results they currently observe.
     *
     * @param functionCallGraph function call graph to summarize
     * @param init currently unused (see note above)
     */
    public FunctionCallSizeInfo(FunctionCallGraph functionCallGraph, boolean init) {
        this._fgraph = functionCallGraph;
        this._fcand = new HashSet<>();
        this._fcandUnary = new HashSet<>();
        this._fcandSafeNNZ = new HashMap<>();
        this._fSafeLiterals = new HashMap<>();
        constructFunctionCallSizeInfo();
    }

    /**
     * Returns the number of call sites of the given function.
     *
     * @param str function key
     * @return number of recorded calls, or 0 if the call graph has no call list for the key
     */
    public int getFunctionCallCount(String str) {
        List<FunctionOp> calls = this._fgraph.getFunctionCalls(str);
        return (calls == null) ? 0 : calls.size();
    }

    /**
     * @param str function key
     * @return true if the function is a valid candidate for size propagation
     */
    public boolean isValidFunction(String str) {
        return this._fcand.contains(str);
    }

    /**
     * @return the (live, mutable) set of valid propagation candidates
     */
    public Set<String> getValidFunctions() {
        return this._fcand;
    }

    /**
     * Returns the functions that are not valid propagation candidates, as computed by
     * {@code FunctionCallGraph.getReachableFunctions(Set)} with the valid set as argument
     * (presumably: reachable functions excluding the given set — confirm in FunctionCallGraph).
     *
     * @return set of invalid functions
     */
    public Set<String> getInvalidFunctions() {
        return this._fgraph.getReachableFunctions(getValidFunctions());
    }

    /**
     * Registers a function as dimensions-preserving.
     *
     * @param str function key
     */
    public void addDimsPreservingFunction(String str) {
        this._fcandUnary.add(str);
    }

    /**
     * @return the (live, mutable) set of registered dimensions-preserving functions
     */
    public Set<String> getDimsPreservingFunctions() {
        return this._fcandUnary;
    }

    /**
     * @param str function key
     * @return true if the function was registered as dimensions-preserving
     */
    public boolean isDimsPreservingFunction(String str) {
        return this._fcandUnary.contains(str);
    }

    /**
     * @param str function key
     * @param i input position
     * @return true if the nnz of input {@code i} may be propagated into the function
     */
    public boolean isSafeNnz(String str, int i) {
        Set<Integer> safe = this._fcandSafeNNZ.get(str);
        return safe != null && safe.contains(i);
    }

    /**
     * @param str function key
     * @return true if at least one input receives the same literal at every call site
     */
    public boolean hasSafeLiterals(String str) {
        Set<Integer> safe = this._fSafeLiterals.get(str);
        return safe != null && !safe.isEmpty();
    }

    /**
     * @param str function key
     * @param i input position
     * @return true if input {@code i} receives the same literal at every call site
     */
    public boolean isSafeLiteral(String str, int i) {
        Set<Integer> safe = this._fSafeLiterals.get(str);
        return safe != null && safe.contains(i);
    }

    /** Populates all summaries; order matters because the nnz step reads the valid set. */
    private void constructFunctionCallSizeInfo() {
        determineValidFunctions();
        determineSafeNnz();
        determineSafeLiterals();
    }

    /** Adds each reachable function whose call sites all have equivalent inputs to the valid set. */
    private void determineValidFunctions() {
        for (String fkey : this._fgraph.getReachableFunctions()) {
            List<FunctionOp> calls = this._fgraph.getFunctionCalls(fkey);
            if (calls == null || calls.isEmpty()) {
                continue;
            }
            // A single call site is trivially consistent; otherwise compare against the first.
            FunctionOp first = calls.get(0);
            boolean consistent = true;
            for (int i = 1; i < calls.size() && consistent; i++) {
                consistent = hasEquivalentInputs(first, calls.get(i));
            }
            if (consistent) {
                this._fcand.add(fkey);
            }
        }
    }

    /**
     * Returns true if both calls pass inputs with equal known dimensions and nnz, and every
     * literal input position carries a literal of equal value in both calls.
     */
    private static boolean hasEquivalentInputs(FunctionOp first, FunctionOp other) {
        List<Hop> in1 = first.getInput();
        List<Hop> in2 = other.getInput();
        if (in1.size() != in2.size()) {
            // Defensive: calls with different arity are never treated as equivalent.
            return false;
        }
        for (int j = 0; j < in1.size(); j++) {
            Hop a = in1.get(j);
            Hop b = in2.get(j);
            boolean sameSize = a.dimsKnown() && b.dimsKnown()
                && a.getDim1() == b.getDim1()
                && a.getDim2() == b.getDim2()
                && a.getNnz() == b.getNnz();
            if (!sameSize) {
                return false;
            }
            // If either side is a literal, both must be literals with equal values.
            if ((a instanceof LiteralOp || b instanceof LiteralOp)
                && !(a instanceof LiteralOp && b instanceof LiteralOp
                    && HopRewriteUtils.isEqualValue((LiteralOp) a, (LiteralOp) b))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Records, for each valid candidate, the input positions whose nnz is known at the first
     * call site. For multi-call candidates this is safe because determineValidFunctions already
     * required equal nnz across all call sites.
     */
    private void determineSafeNnz() {
        for (String fkey : this._fcand) {
            List<FunctionOp> calls = this._fgraph.getFunctionCalls(fkey);
            if (calls == null || calls.isEmpty()) {
                continue;
            }
            List<Hop> inputs = calls.get(0).getInput();
            Set<Integer> safe = new HashSet<>();
            for (int j = 0; j < inputs.size(); j++) {
                // Check input j itself (previously input 0 was checked for every position).
                if (inputs.get(j).getNnz() >= 0) {
                    safe.add(j);
                }
            }
            this._fcandSafeNNZ.put(fkey, safe);
        }
    }

    /**
     * Records, for each reachable function with calls, the input positions that receive a
     * literal of equal value at every call site.
     */
    private void determineSafeLiterals() {
        for (String fkey : this._fgraph.getReachableFunctions()) {
            List<FunctionOp> calls = this._fgraph.getFunctionCalls(fkey);
            if (calls == null || calls.isEmpty()) {
                continue;
            }
            List<Hop> firstInputs = calls.get(0).getInput();
            Set<Integer> safe = new HashSet<>();
            for (int j = 0; j < firstInputs.size(); j++) {
                if (firstInputs.get(j) instanceof LiteralOp) {
                    safe.add(j);
                }
            }
            // Drop positions where any later call passes a non-literal or a different value.
            for (int i = 1; i < calls.size(); i++) {
                List<Hop> otherInputs = calls.get(i).getInput();
                safe.removeIf(j -> j >= otherInputs.size()
                    || !(otherInputs.get(j) instanceof LiteralOp)
                    || !HopRewriteUtils.isEqualValue(
                        (LiteralOp) firstInputs.get(j), (LiteralOp) otherInputs.get(j)));
            }
            this._fSafeLiterals.put(fkey, safe);
        }
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(new int[]{this._fgraph.hashCode(), this._fcand.hashCode(),
            this._fcandUnary.hashCode(), this._fcandSafeNNZ.hashCode(),
            this._fSafeLiterals.hashCode()});
    }

    /**
     * Two summaries are equal if they refer to the same call graph instance and all recorded
     * sets and maps are equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        // Previously inverted: returned false for every FunctionCallSizeInfo.
        if (!(obj instanceof FunctionCallSizeInfo)) {
            return false;
        }
        FunctionCallSizeInfo that = (FunctionCallSizeInfo) obj;
        return this._fgraph == that._fgraph
            && this._fcand.equals(that._fcand)
            && this._fcandUnary.equals(that._fcandUnary)
            && this._fcandSafeNNZ.equals(that._fcandSafeNNZ)
            && this._fSafeLiterals.equals(that._fSafeLiterals);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Valid functions for propagation: \n");
        for (String fkey : getValidFunctions()) {
            appendCallCount(sb, fkey);
            Set<Integer> safeNnz = this._fcandSafeNNZ.get(fkey);
            if (safeNnz != null && !safeNnz.isEmpty()) {
                sb.append("\n----");
                sb.append(Arrays.toString(safeNnz.toArray(new Integer[0])));
            }
            sb.append(ProgramConverter.NEWLINE);
        }
        Set<String> invalid = getInvalidFunctions();
        if (!invalid.isEmpty()) {
            sb.append("Invalid functions for propagation: \n");
            for (String fkey : invalid) {
                appendCallCount(sb, fkey);
                sb.append(ProgramConverter.NEWLINE);
            }
        }
        if (!getDimsPreservingFunctions().isEmpty()) {
            sb.append("Dimensions-preserving functions: \n");
            for (String fkey : getDimsPreservingFunctions()) {
                appendCallCount(sb, fkey);
                sb.append(ProgramConverter.NEWLINE);
            }
        }
        sb.append("Valid scalars for propagation: \n");
        appendInputPositions(sb, this._fSafeLiterals);
        sb.append("Valid #non-zeros for propagation: \n");
        appendInputPositions(sb, this._fcandSafeNNZ);
        return sb.toString();
    }

    /** Appends "--fkey: callCount" (without a trailing newline). */
    private void appendCallCount(StringBuilder sb, String fkey) {
        sb.append("--");
        sb.append(fkey);
        sb.append(": ");
        sb.append(getFunctionCallCount(fkey));
    }

    /** Appends one line per function listing "pos:inputName " for each recorded position. */
    private void appendInputPositions(StringBuilder sb, Map<String, Set<Integer>> positions) {
        for (Map.Entry<String, Set<Integer>> entry : positions.entrySet()) {
            sb.append("--");
            sb.append(entry.getKey());
            sb.append(": ");
            for (Integer pos : entry.getValue()) {
                sb.append(pos);
                sb.append(":");
                sb.append(this._fgraph.getFunctionCalls(entry.getKey())
                    .get(0).getInput().get(pos.intValue()).getName());
                sb.append(" ");
            }
            sb.append(ProgramConverter.NEWLINE);
        }
    }
}
