package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AMorphingMMColGroup;
import org.apache.sysds.runtime.compress.colgroup.APreAgg;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibUtils.class */
/**
 * Static helpers for partitioning, filtering and merging lists of compressed column groups.
 *
 * <p>All methods are stateless. Several of them treat a {@code double[]} argument as an
 * accumulator that column groups write into through {@code extractCommon} / {@code addToCommon};
 * the exact semantics of those calls are defined by the column group classes, not here.
 */
public final class CLALibUtils {
    protected static final Log LOG = LogFactory.getLog(CLALibUtils.class.getName());

    /**
     * Merges the {@link ColGroupEmpty} groups and the {@link ColGroupConst} groups of a block so
     * that at most one group of each kind remains.
     *
     * <p>If the block contains neither kind, it is left untouched. Otherwise a new group list is
     * allocated on the block containing, in order: all other groups (original relative order),
     * then the single/merged empty group (if any), then the single/merged constant group (if any).
     *
     * @param compressedMatrixBlock the block whose column group list is read and possibly replaced
     */
    public static void combineConstColumns(CompressedMatrixBlock compressedMatrixBlock) {
        List<AColGroup> emptyGroups = new ArrayList<>();
        List<AColGroup> constGroups = new ArrayList<>();
        List<AColGroup> otherGroups = new ArrayList<>();
        for (AColGroup group : compressedMatrixBlock.getColGroups()) {
            if (group instanceof ColGroupEmpty) {
                emptyGroups.add(group);
            } else if (group instanceof ColGroupConst) {
                constGroups.add(group);
            } else {
                otherGroups.add(group);
            }
        }

        if (emptyGroups.isEmpty() && constGroups.isEmpty()) {
            // Nothing to combine; keep the block's existing group list.
            return;
        }

        if (emptyGroups.size() == 1) {
            otherGroups.add(emptyGroups.get(0));
        } else if (emptyGroups.size() > 1) {
            otherGroups.add(combineEmpty(emptyGroups));
        }

        if (constGroups.size() == 1) {
            otherGroups.add(constGroups.get(0));
        } else if (constGroups.size() > 1) {
            otherGroups.add(combineConst(constGroups));
        }

        compressedMatrixBlock.allocateColGroupList(otherGroups);
    }

    /**
     * Reports whether a group list contains anything that {@link #filterGroups} would rewrite.
     *
     * @param list the column groups to inspect
     * @return {@code true} if at least one group is an {@link AMorphingMMColGroup} or a
     *         {@link ColGroupConst}; {@code false} otherwise (including for an empty list)
     */
    public static boolean shouldPreFilter(List<AColGroup> list) {
        for (AColGroup group : list) {
            if (group instanceof AMorphingMMColGroup || group instanceof ColGroupConst) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds a filtered copy of a group list while accumulating common values into {@code dArr}.
     *
     * <ul>
     * <li>{@link AMorphingMMColGroup}: replaced by the result of {@code extractCommon(dArr)}.</li>
     * <li>{@link ColGroupEmpty}: dropped.</li>
     * <li>{@link ColGroupConst}: dropped after calling {@code addToCommon(dArr)}.</li>
     * <li>Anything else: kept as-is.</li>
     * </ul>
     *
     * @param list the column groups to filter
     * @param dArr accumulator passed to the groups; if {@code null}, no filtering is done
     * @return {@code list} itself when {@code dArr} is {@code null}, otherwise a new list
     * @throws NotImplementedException if, after processing, {@code dArr} holds a non-finite value
     */
    public static List<AColGroup> filterGroups(List<AColGroup> list, double[] dArr) {
        if (dArr == null) {
            return list;
        }
        List<AColGroup> filtered = new ArrayList<>();
        for (AColGroup group : list) {
            if (group instanceof AMorphingMMColGroup) {
                filtered.add(((AMorphingMMColGroup) group).extractCommon(dArr));
            } else if (group instanceof ColGroupEmpty) {
                // Empty groups are dropped from the filtered list.
                continue;
            } else if (group instanceof ColGroupConst) {
                ((ColGroupConst) group).addToCommon(dArr);
            } else {
                filtered.add(group);
            }
        }
        // Validation runs after the groups have written into dArr, as the values are only
        // known at that point.
        requireFiniteValues(dArr);
        return filtered;
    }

    /**
     * Splits a group list into pre-aggregate groups and remaining groups, accumulating common
     * values into {@code dArr} along the way.
     *
     * <ul>
     * <li>{@link APreAgg}: appended to {@code list3}.</li>
     * <li>{@link AMorphingMMColGroup}: {@code extractCommon(dArr)} is called; an {@link APreAgg}
     * result goes to {@code list3}, a {@link ColGroupEmpty} result is dropped, anything else is
     * an error.</li>
     * <li>{@link ColGroupEmpty}: dropped.</li>
     * <li>{@link ColGroupConst}: dropped after calling {@code addToCommon(dArr)}.</li>
     * <li>Anything else: appended to {@code list2}.</li>
     * </ul>
     *
     * @param list  the column groups to split
     * @param dArr  accumulator passed to {@code extractCommon} / {@code addToCommon}
     * @param list2 output list receiving the groups that are not pre-aggregate groups
     * @param list3 output list receiving the pre-aggregate groups
     * @throws DMLCompressionException if {@code extractCommon} yields a group that is neither an
     *                                 {@link APreAgg} nor a {@link ColGroupEmpty}
     */
    public static void filterGroupsAndSplitPreAgg(List<AColGroup> list, double[] dArr, List<AColGroup> list2, List<APreAgg> list3) {
        for (AColGroup group : list) {
            if (group instanceof APreAgg) {
                list3.add((APreAgg) group);
            } else if (group instanceof AMorphingMMColGroup) {
                AColGroup extracted = ((AMorphingMMColGroup) group).extractCommon(dArr);
                if (extracted instanceof APreAgg) {
                    list3.add((APreAgg) extracted);
                } else if (!(extracted instanceof ColGroupEmpty)) {
                    // Null-safe so that a null result still surfaces as the same exception type.
                    String type = extracted == null ? "null" : extracted.getClass().getSimpleName();
                    throw new DMLCompressionException(
                        "I did not think this was a problem: extractCommon returned unexpected group type " + type);
                }
            } else if (group instanceof ColGroupEmpty) {
                continue;
            } else if (group instanceof ColGroupConst) {
                ((ColGroupConst) group).addToCommon(dArr);
            } else {
                list2.add(group);
            }
        }
    }

    /**
     * Splits a group list into pre-aggregate groups and remaining groups without any filtering.
     *
     * <p>{@link ColGroupEmpty} groups are dropped. Groups are processed in order, so output lists
     * may already have been appended to when an exception is thrown.
     *
     * @param list  the column groups to split
     * @param list2 output list receiving the groups that are not pre-aggregate groups
     * @param list3 output list receiving the {@link APreAgg} groups
     * @throws NotImplementedException if a {@link ColGroupConst} that is not an {@link APreAgg}
     *                                 is encountered
     */
    public static void splitPreAgg(List<AColGroup> list, List<AColGroup> list2, List<APreAgg> list3) {
        for (AColGroup group : list) {
            if (group instanceof APreAgg) {
                list3.add((APreAgg) group);
            } else if (group instanceof ColGroupEmpty) {
                continue;
            } else if (group instanceof ColGroupConst) {
                throw new NotImplementedException(
                    "ColGroupConst is not supported in splitPreAgg; use filterGroupsAndSplitPreAgg to handle constant groups");
            } else {
                list2.add(group);
            }
        }
    }

    /**
     * Verifies that every entry of the given array is finite.
     *
     * @param values the values to check
     * @throws NotImplementedException if any entry is NaN or infinite
     */
    private static void requireFiniteValues(double[] values) {
        for (double value : values) {
            if (!Double.isFinite(value)) {
                throw new NotImplementedException("Not handling if the values are not finite: " + Arrays.toString(values));
            }
        }
    }

    /** Creates one {@link ColGroupEmpty} covering the sorted column indices of all given groups. */
    private static AColGroup combineEmpty(List<AColGroup> emptyGroups) {
        return new ColGroupEmpty(combineColIndexes(emptyGroups));
    }

    /**
     * Creates one constant group covering the columns of all given groups.
     *
     * @param constGroups groups to merge; every element must be a {@link ColGroupConst}
     * @return the group produced by {@code ColGroupConst.create} for the merged columns and values
     */
    private static AColGroup combineConst(List<AColGroup> constGroups) {
        int[] mergedCols = combineColIndexes(constGroups);
        double[] mergedValues = new double[mergedCols.length];
        for (AColGroup group : constGroups) {
            ColGroupConst constGroup = (ColGroupConst) group;
            int[] cols = constGroup.getColIndices();
            double[] values = constGroup.getValues();
            for (int i = 0; i < cols.length; i++) {
                // mergedCols is sorted, so the target slot of each column is found by binary search.
                mergedValues[Arrays.binarySearch(mergedCols, cols[i])] = values[i];
            }
        }
        return ColGroupConst.create(mergedCols, mergedValues);
    }

    /**
     * Collects the column indices of all given groups into one ascending array.
     *
     * @param groups the groups whose column indices are gathered
     * @return a new sorted array sized by the sum of {@code getNumCols()} over the groups
     */
    private static int[] combineColIndexes(List<AColGroup> groups) {
        int totalCols = 0;
        for (AColGroup group : groups) {
            totalCols += group.getNumCols();
        }
        int[] merged = new int[totalCols];
        int pos = 0;
        for (AColGroup group : groups) {
            for (int col : group.getColIndices()) {
                merged[pos++] = col;
            }
        }
        Arrays.sort(merged);
        return merged;
    }

    /**
     * Delegates to {@code AColGroup.colSum} with a freshly allocated result array.
     *
     * @param list the column groups to pass on
     * @param i    length of the newly allocated array handed to {@code colSum}
     * @param i2   forwarded unchanged as the last argument of {@code colSum}
     *             (presumably the number of rows; confirm against {@code AColGroup.colSum})
     * @return whatever {@code AColGroup.colSum} returns
     */
    public static double[] getColSum(List<AColGroup> list, int i, int i2) {
        return AColGroup.colSum(list, new double[i], i2);
    }
}
