package org.nd4j.autodiff.samediff.transform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

/* loaded from: input_file:org/nd4j/autodiff/samediff/transform/SubGraphPredicate.class */
/**
 * An {@link OpPredicate} that matches a small subgraph rooted at a single function.
 *
 * <p>A function matches when:
 * <ul>
 *   <li>it satisfies the {@link #root} predicate;</li>
 *   <li>its number of inputs / outputs equals {@link #inputCount} / {@link #outputCount} (when set);</li>
 *   <li>for every registered input index, that input exists, is produced by an op, and that
 *       producing op satisfies the registered predicate.</li>
 * </ul>
 *
 * <p>Instances are configured fluently via {@link #withRoot(OpPredicate)} and the {@code with*} methods.
 */
public class SubGraphPredicate extends OpPredicate {
    /** Predicate the root function of a candidate subgraph must satisfy. */
    protected final OpPredicate root;
    /** Required number of inputs (args) of the root function, or {@code null} for no constraint. */
    protected Integer inputCount = null;
    /** Required number of output variables of the root function, or {@code null} for no constraint. */
    protected Integer outputCount = null;
    /**
     * Input index -> predicate the op producing that input must satisfy. These producing ops must
     * match, but are NOT added to the child nodes returned by {@link #getSubGraph}.
     */
    protected Map<Integer, OpPredicate> opInputMatchPredicates = new HashMap<>();
    /**
     * Input index -> predicate the op producing that input must satisfy. These producing ops must
     * match AND are added to the child nodes returned by {@link #getSubGraph} (recursively, when the
     * predicate is itself a {@code SubGraphPredicate}).
     */
    protected Map<Integer, OpPredicate> opInputSubgraphPredicates = new HashMap<>();

    protected SubGraphPredicate(OpPredicate opPredicate) {
        this.root = opPredicate;
    }

    /**
     * Tests whether {@code differentialFunction} is the root of a subgraph described by this predicate.
     *
     * @param sameDiff             graph the function belongs to; used to look up the op producing each input
     * @param differentialFunction candidate root function
     * @return {@code true} if the root predicate, the input/output counts and all per-input predicates match
     */
    @Override // org.nd4j.autodiff.samediff.transform.OpPredicate
    public boolean matches(SameDiff sameDiff, DifferentialFunction differentialFunction) {
        if (!this.root.matches(sameDiff, differentialFunction)) {
            return false;
        }

        SDVariable[] args = differentialFunction.args();
        int numInputs = args == null ? 0 : args.length;
        if (this.inputCount != null && numInputs != this.inputCount) {
            return false;
        }

        SDVariable[] outputs = differentialFunction.outputVariables();
        int numOutputs = outputs == null ? 0 : outputs.length;
        if (this.outputCount != null && numOutputs != this.outputCount) {
            return false;
        }

        // Same order as before: "match-only" predicates first, then the subgraph predicates.
        return inputsMatch(sameDiff, args, numInputs, this.opInputMatchPredicates)
                && inputsMatch(sameDiff, args, numInputs, this.opInputSubgraphPredicates);
    }

    /**
     * Checks every (input index -> predicate) entry against the op that produces that input.
     * An index outside {@code [0, numInputs)} or an input with no producing op is a non-match
     * (rather than an array-index failure).
     */
    private static boolean inputsMatch(SameDiff sameDiff, SDVariable[] args, int numInputs,
                                       Map<Integer, OpPredicate> predicates) {
        for (Map.Entry<Integer, OpPredicate> entry : predicates.entrySet()) {
            int idx = entry.getKey();
            if (idx < 0 || idx >= numInputs) {
                return false;
            }
            DifferentialFunction producer = sameDiff.getVariableOutputOp(args[idx].name());
            if (producer == null || !entry.getValue().matches(sameDiff, producer)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the {@link SubGraph} rooted at {@code differentialFunction}.
     *
     * <p>Child nodes are the ops producing the inputs registered via
     * {@link #withInputSubgraph(int, OpPredicate)}, plus (recursively) the child nodes of any nested
     * {@code SubGraphPredicate}. Inputs registered via {@link #withInputMatching(int, OpPredicate)}
     * are only checked, not collected.
     *
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if the root does not match
     */
    public SubGraph getSubGraph(SameDiff sameDiff, DifferentialFunction differentialFunction) {
        Preconditions.checkState(matches(sameDiff, differentialFunction), "Root function does not match predicate");

        List<DifferentialFunction> childNodes = new ArrayList<>();
        for (Map.Entry<Integer, OpPredicate> entry : this.opInputSubgraphPredicates.entrySet()) {
            SDVariable input = differentialFunction.arg(entry.getKey());
            DifferentialFunction producer = sameDiff.getVariableOutputOp(input.name());
            if (producer == null) {
                continue;
            }
            childNodes.add(producer);

            OpPredicate predicate = entry.getValue();
            if (predicate instanceof SubGraphPredicate) {
                childNodes.addAll(((SubGraphPredicate) predicate).getSubGraph(sameDiff, producer).childNodes);
            }
        }
        return SubGraph.builder().sameDiff(sameDiff).rootNode(differentialFunction).childNodes(childNodes).build();
    }

    /**
     * Creates a predicate whose root function must satisfy {@code opPredicate}.
     *
     * @throws NullPointerException if {@code opPredicate} is null
     */
    public static SubGraphPredicate withRoot(@NonNull OpPredicate opPredicate) {
        if (opPredicate == null) {
            throw new NullPointerException("root is marked @NonNull but is null");
        }
        return new SubGraphPredicate(opPredicate);
    }

    /** Requires the root function to have exactly {@code i} inputs. */
    public SubGraphPredicate withInputCount(int i) {
        this.inputCount = i;
        return this;
    }

    /** Requires the root function to have exactly {@code i} output variables. */
    public SubGraphPredicate withOutputCount(int i) {
        this.outputCount = i;
        return this;
    }

    /**
     * Requires the op producing input {@code i} of the root to satisfy {@code opPredicate}; that op is
     * not included in the resulting subgraph.
     *
     * @throws NullPointerException     if {@code opPredicate} is null
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public SubGraphPredicate withInputMatching(int i, @NonNull OpPredicate opPredicate) {
        if (opPredicate == null) {
            throw new NullPointerException("opPredicate is marked @NonNull but is null");
        }
        checkInputIndex(i);
        this.opInputMatchPredicates.put(i, opPredicate);
        return this;
    }

    /**
     * Requires the op producing input {@code i} of the root to satisfy {@code opPredicate}; that op
     * (and, for a nested {@code SubGraphPredicate}, its children) is included in the resulting subgraph.
     *
     * @throws NullPointerException     if {@code opPredicate} is null
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public SubGraphPredicate withInputSubgraph(int i, @NonNull OpPredicate opPredicate) {
        if (opPredicate == null) {
            throw new NullPointerException("opPredicate is marked @NonNull but is null");
        }
        checkInputIndex(i);
        this.opInputSubgraphPredicates.put(i, opPredicate);
        return this;
    }

    /** Fails fast on a negative input index, which previously surfaced later as an array-index error in matches(). */
    private static void checkInputIndex(int inputIdx) {
        if (inputIdx < 0) {
            throw new IllegalArgumentException("Input index must be >= 0, got " + inputIdx);
        }
    }
}
