package org.nd4j.linalg.util;

import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.BiFunction;
import org.nd4j.linalg.primitives.Triple;

/* loaded from: input_file:org/nd4j/linalg/util/ND4JTestUtils.class */
public class ND4JTestUtils {

    /* loaded from: input_file:org/nd4j/linalg/util/ND4JTestUtils$ComparisonResult.class */
    /**
     * Mutable holder for the outcome of a directory-vs-directory array comparison.
     *
     * <p>Each {@link Triple} carries the file from the first directory, the matching file from
     * the second directory, and a flag telling whether the two were considered equal. Equality,
     * hashing and the string form are all derived from the five lists, read through the getters.
     */
    public static class ComparisonResult {
        /** Every pair of files that was compared. */
        List<Triple<File, File, Boolean>> allResults;
        /** The compared pairs whose flag is {@code true}. */
        List<Triple<File, File, Boolean>> passed;
        /** The compared pairs whose flag is {@code false}. */
        List<Triple<File, File, Boolean>> failed;
        /** Files that were skipped on the first-directory side. */
        List<File> skippedDir1;
        /** Files that were skipped on the second-directory side. */
        List<File> skippedDir2;

        /**
         * Stores the given lists as-is (no defensive copies are made).
         *
         * @param list  value for {@code allResults}
         * @param list2 value for {@code passed}
         * @param list3 value for {@code failed}
         * @param list4 value for {@code skippedDir1}
         * @param list5 value for {@code skippedDir2}
         */
        public ComparisonResult(List<Triple<File, File, Boolean>> list, List<Triple<File, File, Boolean>> list2, List<Triple<File, File, Boolean>> list3, List<File> list4, List<File> list5) {
            this.allResults = list;
            this.passed = list2;
            this.failed = list3;
            this.skippedDir1 = list4;
            this.skippedDir2 = list5;
        }

        /** @return the list of all compared pairs (may be {@code null} if set so) */
        public List<Triple<File, File, Boolean>> getAllResults() {
            return allResults;
        }

        /** @return the list of pairs that passed (may be {@code null} if set so) */
        public List<Triple<File, File, Boolean>> getPassed() {
            return passed;
        }

        /** @return the list of pairs that failed (may be {@code null} if set so) */
        public List<Triple<File, File, Boolean>> getFailed() {
            return failed;
        }

        /** @return the skipped files on the first-directory side (may be {@code null} if set so) */
        public List<File> getSkippedDir1() {
            return skippedDir1;
        }

        /** @return the skipped files on the second-directory side (may be {@code null} if set so) */
        public List<File> getSkippedDir2() {
            return skippedDir2;
        }

        /** @param list new value for {@code allResults} */
        public void setAllResults(List<Triple<File, File, Boolean>> list) {
            allResults = list;
        }

        /** @param list new value for {@code passed} */
        public void setPassed(List<Triple<File, File, Boolean>> list) {
            passed = list;
        }

        /** @param list new value for {@code failed} */
        public void setFailed(List<Triple<File, File, Boolean>> list) {
            failed = list;
        }

        /** @param list new value for {@code skippedDir1} */
        public void setSkippedDir1(List<File> list) {
            skippedDir1 = list;
        }

        /** @param list new value for {@code skippedDir2} */
        public void setSkippedDir2(List<File> list) {
            skippedDir2 = list;
        }

        /**
         * Two results are equal when the other object is a {@code ComparisonResult} that accepts
         * this one via {@link #canEqual(Object)} and all five lists are pairwise equal
         * (two {@code null} lists count as equal).
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof ComparisonResult)) {
                return false;
            }
            ComparisonResult that = (ComparisonResult) obj;
            // && short-circuits, so later getters are not touched after the first mismatch.
            return that.canEqual(this)
                    && nullSafeEquals(getAllResults(), that.getAllResults())
                    && nullSafeEquals(getPassed(), that.getPassed())
                    && nullSafeEquals(getFailed(), that.getFailed())
                    && nullSafeEquals(getSkippedDir1(), that.getSkippedDir1())
                    && nullSafeEquals(getSkippedDir2(), that.getSkippedDir2());
        }

        /**
         * Hook used by {@link #equals(Object)} so that both sides of a comparison can veto it.
         *
         * @param obj the object asking to be compared with this one
         * @return {@code true} if {@code obj} is a {@code ComparisonResult}
         */
        protected boolean canEqual(Object obj) {
            return obj instanceof ComparisonResult;
        }

        /**
         * Combines the hash codes of the five lists with multiplier 59, substituting 43 for a
         * {@code null} list.
         */
        @Override
        public int hashCode() {
            int result = 59 + nullSafeHash(getAllResults());
            result = result * 59 + nullSafeHash(getPassed());
            result = result * 59 + nullSafeHash(getFailed());
            result = result * 59 + nullSafeHash(getSkippedDir1());
            result = result * 59 + nullSafeHash(getSkippedDir2());
            return result;
        }

        /** @return a single-line rendering listing all five lists by name */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("ND4JTestUtils.ComparisonResult(allResults=");
            text.append(getAllResults());
            text.append(", passed=").append(getPassed());
            text.append(", failed=").append(getFailed());
            text.append(", skippedDir1=").append(getSkippedDir1());
            text.append(", skippedDir2=").append(getSkippedDir2());
            return text.append(")").toString();
        }

        /** Equality that treats two {@code null} references as equal. */
        private static boolean nullSafeEquals(Object left, Object right) {
            if (left == null) {
                return right == null;
            }
            return left.equals(right);
        }

        /** Hash code of {@code value}, or 43 when {@code value} is {@code null}. */
        private static int nullSafeHash(Object value) {
            if (value == null) {
                return 43;
            }
            return value.hashCode();
        }
    }

    /* loaded from: input_file:org/nd4j/linalg/util/ND4JTestUtils$EqualsFn.class */
    public static class EqualsFn implements BiFunction<INDArray, INDArray, Boolean> {
        public Boolean apply(INDArray iNDArray, INDArray iNDArray2) {
            return Boolean.valueOf(iNDArray.equals(iNDArray2));
        }
    }

    /* loaded from: input_file:org/nd4j/linalg/util/ND4JTestUtils$EqualsWithEpsFn.class */
    public static class EqualsWithEpsFn implements BiFunction<INDArray, INDArray, Boolean> {
        private final double eps;

        public Boolean apply(INDArray iNDArray, INDArray iNDArray2) {
            return Boolean.valueOf(iNDArray.equalsWithEps(iNDArray2, this.eps));
        }

        public EqualsWithEpsFn(double d) {
            this.eps = d;
        }
    }

    /** Private and empty: this class only exposes static members, so it is not meant to be instantiated. */
    private ND4JTestUtils() {
    }

    /**
     * Compares the serialized arrays found in two directories using exact equality
     * ({@link EqualsFn}). Convenience overload of the four-argument variant.
     *
     * @param file  first directory
     * @param file2 second directory
     * @param z     forwarded unchanged to the four-argument overload
     * @return the comparison outcome produced by the four-argument overload
     * @throws Exception if the four-argument overload throws
     */
    public static ComparisonResult validateSerializedArrays(File file, File file2, boolean z) throws Exception {
        BiFunction<INDArray, INDArray, Boolean> exactEquality = new EqualsFn();
        return validateSerializedArrays(file, file2, z, exactEquality);
    }

    /**
     * Compares the serialized arrays stored in two directories, file by file.
     *
     * <p>Files are paired by their path relative to the directory they were listed from. Each
     * pair is loaded with {@code Nd4j.readBinary} and handed to {@code biFunction}; the outcome
     * is recorded as a {@link Triple} of (file from directory 1, file from directory 2, result).
     * Files whose relative path exists in only one directory are not compared and are reported
     * as skipped for that directory. All returned lists are sorted: the triples by their first
     * file, the skipped files by natural {@link File} order.
     *
     * @param file       first directory
     * @param file2      second directory
     * @param z          passed as the {@code recursive} flag of {@code FileUtils.listFiles}
     * @param biFunction decides whether two loaded arrays match; must return a non-null Boolean
     * @return all compared pairs, split into passed/failed, plus the files skipped on each side
     * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, when
     *                               either directory yields no files - confirm against Preconditions
     * @throws Exception             if listing or reading the files fails
     */
    public static ComparisonResult validateSerializedArrays(File file, File file2, boolean z, BiFunction<INDArray, INDArray, Boolean> biFunction) throws Exception {
        File[] dir1Files = FileUtils.listFiles(file, (String[]) null, z).toArray(new File[0]);
        File[] dir2Files = FileUtils.listFiles(file2, (String[]) null, z).toArray(new File[0]);
        Preconditions.checkState(dir1Files.length > 0, "No files found for directory 1: %s", file.getAbsolutePath());
        Preconditions.checkState(dir2Files.length > 0, "No files found for directory 2: %s", file2.getAbsolutePath());

        Map<String, File> dir1ByPath = indexByRelativePath(file, dir1Files);
        Map<String, File> dir2ByPath = indexByRelativePath(file2, dir2Files);

        // Each side's skipped files must come from that side's own map; taking the directory 2
        // entries from the directory 1 map yields nulls and breaks the sort further down.
        List<File> skippedDir1 = filesOnlyIn(dir1ByPath, dir2ByPath);
        List<File> skippedDir2 = filesOnlyIn(dir2ByPath, dir1ByPath);

        List<Triple<File, File, Boolean>> allResults = new ArrayList<Triple<File, File, Boolean>>();
        List<Triple<File, File, Boolean>> passed = new ArrayList<Triple<File, File, Boolean>>();
        List<Triple<File, File, Boolean>> failed = new ArrayList<Triple<File, File, Boolean>>();
        for (Map.Entry<String, File> entry : dir1ByPath.entrySet()) {
            File first = entry.getValue();
            File second = dir2ByPath.get(entry.getKey());
            if (second == null) {
                // No counterpart in directory 2: already accounted for in skippedDir1.
                continue;
            }
            boolean matches = biFunction.apply(Nd4j.readBinary(first), Nd4j.readBinary(second)).booleanValue();
            Triple<File, File, Boolean> outcome = new Triple<File, File, Boolean>(first, second, Boolean.valueOf(matches));
            allResults.add(outcome);
            if (matches) {
                passed.add(outcome);
            } else {
                failed.add(outcome);
            }
        }

        // HashMap iteration order is unspecified, so sort everything for a deterministic result.
        Comparator<Triple<File, File, Boolean>> byFirstFile = new Comparator<Triple<File, File, Boolean>>() {
            @Override
            public int compare(Triple<File, File, Boolean> left, Triple<File, File, Boolean> right) {
                return left.getFirst().compareTo(right.getFirst());
            }
        };
        Collections.sort(allResults, byFirstFile);
        Collections.sort(passed, byFirstFile);
        Collections.sort(failed, byFirstFile);
        Collections.sort(skippedDir1);
        Collections.sort(skippedDir2);
        return new ComparisonResult(allResults, passed, failed, skippedDir1, skippedDir2);
    }

    /** Maps each regular file in {@code files} to its path relative to {@code baseDir}. */
    private static Map<String, File> indexByRelativePath(File baseDir, File[] files) {
        Map<String, File> byPath = new HashMap<String, File>();
        URI baseUri = baseDir.toURI();
        for (File candidate : files) {
            if (candidate.isFile()) {
                byPath.put(baseUri.relativize(candidate.toURI()).getPath(), candidate);
            }
        }
        return byPath;
    }

    /** Returns the files of {@code source} whose relative path has no entry in {@code other}. */
    private static List<File> filesOnlyIn(Map<String, File> source, Map<String, File> other) {
        List<File> unmatched = new ArrayList<File>();
        for (Map.Entry<String, File> entry : source.entrySet()) {
            if (!other.containsKey(entry.getKey())) {
                unmatched.add(entry.getValue());
            }
        }
        return unmatched;
    }
}
