package org.apache.sysds.hops.cost;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;

/* loaded from: input_file:org/apache/sysds/hops/cost/ComputeCost.class */
/**
 * Heuristic compute-cost model for individual HOPs.
 *
 * <p>The cost of a HOP is a hard-coded per-cell weight for its operator, multiplied by the
 * number of output cells. Each output dimension is clamped to at least 1. Operators without
 * a dedicated weight get a neutral weight of 1. Unrecognized opcodes of the modelled HOP
 * types additionally emit a warning.
 */
public class ComputeCost {
    private static final Log LOG = LogFactory.getLog(ComputeCost.class.getName());

    /**
     * Estimates the compute cost of a single HOP.
     *
     * @param hop the HOP to cost; must be non-null, because its dimensions are read to scale
     *            the per-cell weight
     * @return the per-cell operator weight times the number of output cells
     */
    public static double getHOPComputeCost(Hop hop) {
        double costs = 1.0d;
        // Dispatch order mirrors the original instanceof checks.
        if (hop instanceof UnaryOp) {
            costs = getUnaryOpCost((UnaryOp) hop);
        } else if (hop instanceof BinaryOp) {
            costs = getBinaryOpCost((BinaryOp) hop);
        } else if (hop instanceof TernaryOp) {
            costs = getTernaryOpCost((TernaryOp) hop);
        } else if (hop instanceof NaryOp) {
            // Nary min/max/plus touch every input once per cell.
            costs = HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS)
                ? hop.getInput().size() : 1.0d;
        } else if (hop instanceof ParameterizedBuiltinOp
            || hop instanceof IndexingOp
            || hop instanceof ReorgOp) {
            costs = 1.0d;
        } else if (hop instanceof DnnOp) {
            costs = getDnnOpCost((DnnOp) hop);
        } else if (hop instanceof AggBinaryOp) {
            costs = getAggBinaryOpCost((AggBinaryOp) hop);
        } else if (hop instanceof AggUnaryOp) {
            costs = getAggUnaryOpCost((AggUnaryOp) hop);
        }
        return costs * getSize(hop);
    }

    /** Per-cell weight of an elementwise unary operator. */
    private static double getUnaryOpCost(UnaryOp hop) {
        switch (hop.getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
                return 1.0d;
            case SPROP:
            case SQRT:
                return 2.0d;
            case EXP:
                return 18.0d;
            case SIGMOID:
                return 21.0d;
            case LOG:
            case LOG_NZ:
                return 32.0d;
            case NCOL:
            case NROW:
            case PRINT:
            case ASSERT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
                return 1.0d;
            case SIN:
                return 18.0d;
            case COS:
                return 22.0d;
            case TAN:
                return 42.0d;
            case ASIN:
                return 93.0d;
            case ACOS:
                return 103.0d;
            case ATAN:
                return 40.0d;
            case SINH:
                return 93.0d;
            case COSH:
                return 103.0d;
            case TANH:
                return 40.0d;
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                return 1.0d;
            case CUMSUMPROD:
                return 2.0d;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                return 1.0d;
        }
    }

    /** Per-cell weight of a binary operator. */
    private static double getBinaryOpCost(BinaryOp hop) {
        switch (hop.getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                return 1.0d;
            case INTDIV:
                return 6.0d;
            case MODULUS:
                return 8.0d;
            case DIV:
                return 22.0d;
            case LOG:
            case LOG_NZ:
                return 32.0d;
            case POW:
                // Squaring (literal exponent 2) is as cheap as a multiply.
                return HopRewriteUtils.isLiteralOfValue(hop.getInput().get(1), 2.0d) ? 1 : 16;
            case MINUS_NZ:
            case MINUS1_MULT:
                return 2.0d;
            case MOMENT:
                // Previously fell through into COV, so the order-specific weight was lost.
                return getMomentCost(hop, false);
            case COV:
                return 23.0d;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                return 1.0d;
        }
    }

    /** Per-cell weight of a ternary operator. */
    private static double getTernaryOpCost(TernaryOp hop) {
        switch (hop.getOp()) {
            case IFELSE:
            case PLUS_MULT:
            case MINUS_MULT:
                return 2.0d;
            case CTABLE:
                return 3.0d;
            case MOMENT:
                // Previously fell through into COV, so the order-specific weight was lost.
                // NOTE(review): the order is read from input 1 (same as the binary form);
                // confirm that input 1 is the order literal, not another operand, for the
                // ternary form.
                return getMomentCost(hop, true);
            case COV:
                return 23.0d;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                return 1.0d;
        }
    }

    /**
     * Per-cell weight of a moment operator.
     *
     * <p>The order is taken from the literal value of input 1 when that input is a
     * {@link LiteralOp}, and defaults to 2 otherwise. Orders outside 0..5 get the
     * neutral weight 1.
     *
     * @param hop      the moment HOP
     * @param weighted true for the ternary form, which costs one extra unit per cell
     *                 for every modelled order
     * @return the per-cell weight
     */
    private static double getMomentCost(Hop hop, boolean weighted) {
        Hop orderInput = hop.getInput().get(1);
        int order = (int) (orderInput instanceof LiteralOp
            ? HopRewriteUtils.getIntValueSafe((LiteralOp) orderInput) : 2L);
        double base;
        switch (order) {
            case 0:
                base = 1.0d;
                break;
            case 1:
                base = 8.0d;
                break;
            case 2:
                base = 16.0d;
                break;
            case 3:
                base = 31.0d;
                break;
            case 4:
                base = 51.0d;
                break;
            case 5:
                base = 16.0d;
                break;
            default:
                return 1.0d;
        }
        return weighted ? base + 1.0d : base;
    }

    /** Per-cell weight of a DNN operator; only bias add/multiply are modelled. */
    private static double getDnnOpCost(DnnOp hop) {
        switch (hop.getOp()) {
            case BIASADD:
            case BIASMULT:
                return 2.0d;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                return 1.0d;
        }
    }

    /**
     * Per-cell weight of a matrix multiplication.
     *
     * <p>The weight is 2 operations per element of the common dimension (the left input's
     * columns). It is scaled by the left input's sparsity when its dimensions and
     * nnz are known.
     */
    private static double getAggBinaryOpCost(AggBinaryOp hop) {
        Hop left = hop.getInput().get(0);
        // Clamp like getSize: an unknown (-1) dimension previously produced a negative cost.
        double costs = 2 * Math.max(left.getDim2(), 1L);
        if (left.dimsKnown(true)) {
            costs *= left.getSparsity();
        }
        return costs;
    }

    /**
     * Per-cell weight of a unary aggregate.
     *
     * <p>This is the operator weight times the number of input cells folded into each
     * output cell: input rows for column aggregates, input columns for row aggregates,
     * and all input cells for a full aggregate.
     */
    private static double getAggUnaryOpCost(AggUnaryOp hop) {
        double costs;
        switch (hop.getOp()) {
            case SUM:
                costs = 4.0d;
                break;
            case SUM_SQ:
                costs = 5.0d;
                break;
            case MIN:
            case MAX:
                costs = 1.0d;
                break;
            default:
                LOG.warn("Cost model not implemented yet for: " + hop.getOp());
                costs = 1.0d;
                break;
        }
        Hop input = hop.getInput().get(0);
        switch (hop.getDirection()) {
            case Col:
                costs *= Math.max(input.getDim1(), 1L);
                break;
            case Row:
                costs *= Math.max(input.getDim2(), 1L);
                break;
            case RowCol:
                costs *= getSize(input);
                break;
        }
        return costs;
    }

    /** Number of cells of the given HOP's output, clamping each dimension to at least 1. */
    private static long getSize(Hop hop) {
        return Math.max(hop.getDim1(), 1L) * Math.max(hop.getDim2(), 1L);
    }
}
