package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.class */
/**
 * Hop rewrite that flattens a chain of element-wise multiplications ({@code OpOp2.MULT} binary
 * ops) and rebuilds it as a regrouped product of the chain's distinct leaves.
 *
 * <p>A leaf that occurs {@code k > 1} times in the chain is replaced by {@code leaf ^ k}
 * ({@code OpOp2.POW}). The rebuilt product has the shape
 * {@code ((others * matrices) * rowVectors) * scalarsAndColVectors}. Within each group, operands
 * follow the order defined by {@link #compareLeaves(Hop, Hop)}.
 *
 * <p>The rewrite only fires when all of the following hold:
 * <ul>
 *   <li>the chain root has known dimensions;</li>
 *   <li>the chain contains at least two e-mults;</li>
 *   <li>no inner e-mult is consumed by a hop outside the chain.</li>
 * </ul>
 */
public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {

    /** Orders leaves by data type, then matrix shape, then nnz (descending), then hop id. */
    private static final Comparator<Hop> LEAF_ORDER =
        RewriteElementwiseMultChainOptimization::compareLeaves;

    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (int i = 0; i < roots.size(); i++) {
            roots.set(i, rule_RewriteEMult(roots.get(i)));
        }
        return roots;
    }

    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        return rule_RewriteEMult(root);
    }

    /** Returns true if {@code hop} is a {@link BinaryOp} whose operator is {@code MULT}. */
    private static boolean isBinaryMult(Hop hop) {
        return hop instanceof BinaryOp && ((BinaryOp) hop).getOp() == Types.OpOp2.MULT;
    }

    /**
     * Applies the rewrite at {@code root} if possible, then continues into the sub-DAG below.
     *
     * @param root hop to process; already-visited hops are returned unchanged
     * @return the hop that now stands in place of {@code root}: the replacement product if the
     *     chain was rewritten, otherwise {@code root} itself
     */
    private static Hop rule_RewriteEMult(Hop root) {
        if (root.isVisited()) {
            return root;
        }
        root.setVisited();

        if (isBinaryMult(root) && root.dimsKnown()) {
            Set<BinaryOp> emults = new HashSet<>();
            Map<Hop, Integer> leaves = new HashMap<>(); // leaf -> multiplicity in the chain
            findEMultsAndLeaves((BinaryOp) root, emults, leaves);

            ArrayList<Hop> inputs = root.getInput();
            // Only rewrite chains of >= 2 e-mults whose inner e-mults are not shared outside it.
            if (emults.size() >= 2
                && isExclusiveOperand(emults, inputs.get(0))
                && isExclusiveOperand(emults, inputs.get(1))) {
                Hop replacement = constructReplacement(emults, leaves);
                if (LOG.isDebugEnabled()) {
                    LOG.debug(String.format(
                        "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
                        emults.size(), root.getHopID(), replacement.getHopID()));
                }
                Hop newRoot = HopRewriteUtils.rewireAllParentChildReferences(root, replacement);
                // Leaves were marked visited by constructPower, so continue below them directly;
                // the interior e-mults were discarded and need no further processing.
                for (Hop leaf : leaves.keySet()) {
                    recurseInputs(leaf);
                }
                return newRoot;
            }
        }
        recurseInputs(root);
        return root;
    }

    /**
     * Applies the rewrite to every input of {@code parent} and stores the resulting hop back.
     *
     * <p>Index-based on purpose: a rewrite below may rewire this parent's input list through
     * {@code rewireAllParentChildReferences}, so an iterator over the list is not safe here.
     */
    private static void recurseInputs(Hop parent) {
        ArrayList<Hop> inputs = parent.getInput();
        for (int i = 0; i < inputs.size(); i++) {
            inputs.set(i, rule_RewriteEMult(inputs.get(i)));
        }
    }

    /**
     * Builds the regrouped product for a flattened e-mult chain.
     *
     * <p>Each leaf is first detached from the e-mults in {@code emults}; parents outside the chain
     * are kept. It is then raised to its multiplicity. The powered leaves are sorted by
     * {@link #LEAF_ORDER} and multiplied group-wise as
     * {@code ((others * matrices) * rowVectors) * scalarsAndColVectors}.
     *
     * @param emults the e-mults of the chain being replaced
     * @param leaves each distinct leaf and how often it occurs in the chain
     * @return root of the new product; every newly created hop is marked visited
     */
    private static Hop constructReplacement(Set<BinaryOp> emults, Map<Hop, Integer> leaves) {
        TreeSet<Hop> sorted = new TreeSet<>(LEAF_ORDER);
        for (Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
            Hop leaf = entry.getKey();
            leaf.getParent().removeIf(p -> p instanceof BinaryOp && emults.contains(p));
            sorted.add(constructPower(leaf, entry.getValue()));
        }

        // LEAF_ORDER sorts scalars, then column vectors, then row vectors, then other matrices,
        // then all other data types. Classifying each leaf in that order therefore yields four
        // contiguous runs, multiplied in the same sequence as consuming them one run at a time.
        Hop scalarsAndColVectors = null;
        Hop rowVectors = null;
        Hop matrices = null;
        Hop others = null;
        for (Hop leaf : sorted) {
            boolean isMatrix = leaf.getDataType() == Types.DataType.MATRIX;
            if (leaf.getDataType() == Types.DataType.SCALAR || (isMatrix && leaf.getDim2() == 1)) {
                // this group places the newest operand on the left
                scalarsAndColVectors = multiply(leaf, scalarsAndColVectors);
            } else if (isMatrix && leaf.getDim1() == 1) {
                rowVectors = multiply(rowVectors, leaf);
            } else if (isMatrix) {
                matrices = multiply(matrices, leaf);
            } else {
                others = multiply(others, leaf);
            }
        }
        return multiply(multiply(multiply(others, matrices), rowVectors), scalarsAndColVectors);
    }

    /**
     * Null-aware product used to fold operands together.
     *
     * @return {@code right} if {@code left} is null, {@code left} if {@code right} is null
     *     (so {@code null} if both are), otherwise a new visited {@code MULT} hop
     *     {@code left * right}
     */
    private static Hop multiply(Hop left, Hop right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        Hop product = HopRewriteUtils.createBinary(left, right, Types.OpOp2.MULT);
        product.setVisited();
        return product;
    }

    /**
     * Marks {@code hop} visited and returns it raised to {@code count}.
     *
     * @param hop the leaf operand
     * @param count multiplicity of the leaf in the chain; must be at least 1
     * @return {@code hop} itself when {@code count == 1}, otherwise a new visited {@code POW} hop
     */
    private static Hop constructPower(Hop hop, int count) {
        assert count >= 1;
        hop.setVisited();
        if (count == 1) {
            return hop;
        }
        BinaryOp pow = HopRewriteUtils.createBinary(hop, new LiteralOp(count), Types.OpOp2.POW);
        pow.setVisited();
        return pow;
    }

    /**
     * Returns true if {@code operand} is not an e-mult, or is an e-mult whose sub-chain has no
     * foreign parents according to {@link #checkForeignParent(Set, BinaryOp)}.
     */
    private static boolean isExclusiveOperand(Set<BinaryOp> emults, Hop operand) {
        return !isBinaryMult(operand) || checkForeignParent(emults, (BinaryOp) operand);
    }

    /**
     * Checks recursively that {@code emult} and every e-mult below it are consumed only by
     * e-mults in {@code emults}.
     *
     * <p>Parents are inspected only when a hop has more than one parent.
     *
     * @return true if no foreign parent was found, false otherwise
     */
    private static boolean checkForeignParent(Set<BinaryOp> emults, BinaryOp emult) {
        ArrayList<Hop> parents = emult.getParent();
        if (parents.size() > 1) {
            for (Hop parent : parents) {
                if (!(parent instanceof BinaryOp && emults.contains(parent))) {
                    return false;
                }
            }
        }
        ArrayList<Hop> inputs = emult.getInput();
        return isExclusiveOperand(emults, inputs.get(0)) && isExclusiveOperand(emults, inputs.get(1));
    }

    /**
     * Walks the e-mult chain rooted at {@code emult}.
     *
     * <p>Every e-mult found is added to {@code emults}. Every non-e-mult operand is counted in
     * {@code leaves}; a leaf reached several times is counted once per occurrence.
     */
    private static void findEMultsAndLeaves(BinaryOp emult, Set<BinaryOp> emults,
        Map<Hop, Integer> leaves) {
        emults.add(emult);
        ArrayList<Hop> inputs = emult.getInput();
        for (int i = 0; i < 2; i++) {
            Hop operand = inputs.get(i);
            if (isBinaryMult(operand)) {
                findEMultsAndLeaves((BinaryOp) operand, emults, leaves);
            } else {
                addMultiset(leaves, operand);
            }
        }
    }

    /** Increments the count of {@code key} in a map used as a multiset (absent counts as 0). */
    private static <K> void addMultiset(Map<K, Integer> map, K key) {
        map.merge(key, 1, Integer::sum);
    }

    /**
     * Total order on leaves used by {@link #LEAF_ORDER}.
     *
     * <p>Hops are ordered first by data type: scalar, matrix, tensor, frame, then all remaining
     * types, with list last. Matrices are further ordered as follows:
     * <ul>
     *   <li>column vectors ({@code dim2 == 1}) first, then row vectors ({@code dim1 == 1}), then
     *       all other matrices;</li>
     *   <li>within a shape class, more non-zeros first, then ascending hop id.</li>
     * </ul>
     * Non-matrix hops of the same rank are ordered by hop id.
     */
    private static int compareLeaves(Hop a, Hop b) {
        int byType = Integer.compare(dataTypeRank(a.getDataType()), dataTypeRank(b.getDataType()));
        if (byType != 0) {
            return byType;
        }
        if (a.getDataType() != Types.DataType.MATRIX) {
            return Long.compare(a.getHopID(), b.getHopID());
        }
        int byShape = Integer.compare(matrixShapeRank(a), matrixShapeRank(b));
        if (byShape != 0) {
            return byShape;
        }
        int byNnzDescending = Long.compare(b.getNnz(), a.getNnz());
        return byNnzDescending != 0 ? byNnzDescending : Long.compare(a.getHopID(), b.getHopID());
    }

    /** Sort rank of a data type; types without an explicit case share rank 4. */
    private static int dataTypeRank(Types.DataType dataType) {
        switch (dataType) {
            case SCALAR:
                return 0;
            case MATRIX:
                return 1;
            case TENSOR:
                return 2;
            case FRAME:
                return 3;
            case LIST:
                return 5;
            default:
                return 4;
        }
    }

    /** 0 for a column vector ({@code dim2 == 1}), 1 for a row vector, 2 for any other matrix. */
    private static int matrixShapeRank(Hop matrix) {
        if (matrix.getDim2() == 1) {
            return 0;
        }
        if (matrix.getDim1() == 1) {
            return 1;
        }
        return 2;
    }
}
