package org.nd4j.autodiff.samediff;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.learning.config.IUpdater;

/**
 * Training configuration: the updater (optimizer), L1/L2 regularization coefficients, the
 * {@code minimize} flag (default {@code true}), and the mappings between DataSet array positions
 * and variable/placeholder names. It also lists the optional trainable parameters and holds the
 * iteration/epoch counters.
 *
 * <p>The iteration and epoch counters are mutable and take part in {@link #equals(Object)} and
 * {@link #hashCode()}. An instance's hash therefore changes whenever a counter is incremented.
 * Avoid using instances as keys in hash-based collections.
 *
 * <p>Instances are usually created via {@link #builder()}, which validates its inputs.
 */
public class TrainingConfig {
    private IUpdater updater;
    private double l1;
    private double l2;
    private boolean minimize = true;
    private List<String> dataSetFeatureMapping;
    private List<String> dataSetLabelMapping;
    private List<String> dataSetFeatureMaskMapping;
    private List<String> dataSetLabelMaskMapping;
    private List<String> trainableParams;
    private int iterationCount;
    private int epochCount;

    /**
     * Validating fluent builder for {@link TrainingConfig}.
     *
     * <p>{@code minimize} defaults to {@code true}. {@link #build()} requires an updater, a feature
     * mapping and a label mapping.
     */
    public static class Builder {
        private IUpdater updater;
        private double l1;
        private double l2;
        private boolean minimize = true;
        private List<String> dataSetFeatureMapping;
        private List<String> dataSetLabelMapping;
        private List<String> dataSetFeatureMaskMapping;
        private List<String> dataSetLabelMaskMapping;
        private List<String> trainableParams;

        /** Sets the updater (optimizer). Required by {@link #build()}. */
        public Builder updater(IUpdater updater) {
            this.updater = updater;
            return this;
        }

        /**
         * Sets the L1 regularization coefficient.
         *
         * @param l1 coefficient; must be {@code >= 0} (NaN is rejected as well)
         */
        public Builder l1(double l1) {
            Preconditions.checkState(l1 >= 0.0d, "L1 regularization coefficient must be >= 0. Got %s", l1);
            this.l1 = l1;
            return this;
        }

        /**
         * Sets the L2 regularization coefficient.
         *
         * @param l2 coefficient; must be {@code >= 0} (NaN is rejected as well)
         */
        public Builder l2(double l2) {
            Preconditions.checkState(l2 >= 0.0d, "L2 regularization coefficient must be >= 0. Got %s", l2);
            this.l2 = l2;
            return this;
        }

        /** Sets the {@code minimize} flag (default {@code true}). */
        public Builder minimize(boolean minimize) {
            this.minimize = minimize;
            return this;
        }

        /**
         * Sets the feature mapping: the variable/placeholder name for each DataSet feature array
         * position.
         *
         * @param names at least one name
         */
        public Builder dataSetFeatureMapping(String... names) {
            // Convert a null array to a null list so the precondition below reports it clearly
            // instead of Arrays.asList throwing a bare NullPointerException.
            return dataSetFeatureMapping(names == null ? null : Arrays.asList(names));
        }

        /**
         * Sets the feature mapping: the variable/placeholder name for each DataSet feature array
         * position.
         *
         * @param mapping non-null, non-empty list of names
         */
        public Builder dataSetFeatureMapping(List<String> mapping) {
            // Must validate the boolean itself: checkNotNull on a boxed Boolean never fails.
            Preconditions.checkState(mapping != null && !mapping.isEmpty(), "No feature mapping was provided");
            this.dataSetFeatureMapping = mapping;
            return this;
        }

        /**
         * Sets the label mapping: the variable/placeholder name for each DataSet label array
         * position.
         *
         * @param names at least one name
         */
        public Builder dataSetLabelMapping(String... names) {
            return dataSetLabelMapping(names == null ? null : Arrays.asList(names));
        }

        /**
         * Sets the label mapping: the variable/placeholder name for each DataSet label array
         * position.
         *
         * @param mapping non-null, non-empty list of names
         */
        public Builder dataSetLabelMapping(List<String> mapping) {
            Preconditions.checkState(mapping != null && !mapping.isEmpty(), "No label mapping was provided");
            this.dataSetLabelMapping = mapping;
            return this;
        }

        /** Sets the (optional) feature mask mapping. Not validated. */
        public Builder dataSetFeatureMaskMapping(List<String> mapping) {
            this.dataSetFeatureMaskMapping = mapping;
            return this;
        }

        /** Sets the (optional) label mask mapping. Not validated. */
        public Builder dataSetLabelMaskMapping(List<String> mapping) {
            this.dataSetLabelMaskMapping = mapping;
            return this;
        }

        /** Sets the names of the trainable parameters. */
        public Builder trainableParams(String... names) {
            return trainableParams(Arrays.asList(names));
        }

        /** Sets the names of the trainable parameters; left {@code null} if never called. */
        public Builder trainableParams(List<String> names) {
            this.trainableParams = names;
            return this;
        }

        /**
         * Builds the configuration. The iteration and epoch counters start at 0.
         *
         * @throws IllegalStateException if the updater, feature mapping or label mapping is unset
         */
        public TrainingConfig build() {
            Preconditions.checkState(this.updater != null, "Updater (optimizer) must not be null. Use updater(IUpdater) to set an updater");
            Preconditions.checkState(this.dataSetFeatureMapping != null, "No DataSet feature mapping has been provided. A mapping between DataSet array positions and variables/placeholders must be provided - use dataSetFeatureMapping(...) to set this");
            Preconditions.checkState(this.dataSetLabelMapping != null, "No DataSet label mapping has been provided. A mapping between DataSet array positions and variables/placeholders must be provided - use dataSetLabelMapping(...) to set this");
            return new TrainingConfig(this.updater, this.l1, this.l2, this.minimize, this.dataSetFeatureMapping,
                    this.dataSetLabelMapping, this.dataSetFeatureMaskMapping, this.dataSetLabelMaskMapping,
                    this.trainableParams);
        }
    }

    /**
     * Non-validating, all-fields builder (Lombok-style). Package-private constructor; prefer
     * {@link TrainingConfig#builder()}.
     *
     * <p>NOTE(review): {@code minimize} defaults to {@code false} here, unlike {@link Builder}
     * (which defaults to {@code true}). This is kept for compatibility; confirm whether it is
     * intended.
     */
    public static class TrainingConfigBuilder {
        private IUpdater updater;
        private double l1;
        private double l2;
        private boolean minimize;
        private List<String> dataSetFeatureMapping;
        private List<String> dataSetLabelMapping;
        private List<String> dataSetFeatureMaskMapping;
        private List<String> dataSetLabelMaskMapping;
        private List<String> trainableParams;
        private int iterationCount;
        private int epochCount;

        TrainingConfigBuilder() {
        }

        public TrainingConfigBuilder updater(IUpdater updater) {
            this.updater = updater;
            return this;
        }

        public TrainingConfigBuilder l1(double l1) {
            this.l1 = l1;
            return this;
        }

        public TrainingConfigBuilder l2(double l2) {
            this.l2 = l2;
            return this;
        }

        public TrainingConfigBuilder minimize(boolean minimize) {
            this.minimize = minimize;
            return this;
        }

        public TrainingConfigBuilder dataSetFeatureMapping(List<String> mapping) {
            this.dataSetFeatureMapping = mapping;
            return this;
        }

        public TrainingConfigBuilder dataSetLabelMapping(List<String> mapping) {
            this.dataSetLabelMapping = mapping;
            return this;
        }

        public TrainingConfigBuilder dataSetFeatureMaskMapping(List<String> mapping) {
            this.dataSetFeatureMaskMapping = mapping;
            return this;
        }

        public TrainingConfigBuilder dataSetLabelMaskMapping(List<String> mapping) {
            this.dataSetLabelMaskMapping = mapping;
            return this;
        }

        public TrainingConfigBuilder trainableParams(List<String> names) {
            this.trainableParams = names;
            return this;
        }

        public TrainingConfigBuilder iterationCount(int iterationCount) {
            this.iterationCount = iterationCount;
            return this;
        }

        public TrainingConfigBuilder epochCount(int epochCount) {
            this.epochCount = epochCount;
            return this;
        }

        /** Builds a configuration from exactly the values set, without validation. */
        public TrainingConfig build() {
            return new TrainingConfig(this.updater, this.l1, this.l2, this.minimize, this.dataSetFeatureMapping,
                    this.dataSetLabelMapping, this.dataSetFeatureMaskMapping, this.dataSetLabelMaskMapping,
                    this.trainableParams, this.iterationCount, this.epochCount);
        }

        @Override
        public String toString() {
            return "TrainingConfig.TrainingConfigBuilder(updater=" + this.updater + ", l1=" + this.l1 + ", l2=" + this.l2 + ", minimize=" + this.minimize + ", dataSetFeatureMapping=" + this.dataSetFeatureMapping + ", dataSetLabelMapping=" + this.dataSetLabelMapping + ", dataSetFeatureMaskMapping=" + this.dataSetFeatureMaskMapping + ", dataSetLabelMaskMapping=" + this.dataSetLabelMaskMapping + ", trainableParams=" + this.trainableParams + ", iterationCount=" + this.iterationCount + ", epochCount=" + this.epochCount + ")";
        }
    }

    /**
     * Convenience constructor for one feature array and one label array. Sets {@code minimize} to
     * {@code true}, uses empty (immutable) mask mappings and leaves {@code trainableParams} as
     * {@code null}.
     *
     * @param updater       updater (optimizer)
     * @param l1            L1 regularization coefficient (not validated here)
     * @param l2            L2 regularization coefficient (not validated here)
     * @param featureName   name mapped to DataSet feature array 0
     * @param labelName     name mapped to DataSet label array 0
     */
    public TrainingConfig(IUpdater updater, double l1, double l2, String featureName, String labelName) {
        this(updater, l1, l2, true, Collections.singletonList(featureName), Collections.singletonList(labelName),
                Collections.<String>emptyList(), Collections.<String>emptyList(), null);
    }

    /**
     * Creates a configuration with the iteration and epoch counters at 0. Arguments are stored
     * as given, without validation or copying.
     */
    public TrainingConfig(IUpdater updater, double l1, double l2, boolean minimize,
                          List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping,
                          List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping,
                          List<String> trainableParams) {
        this(updater, l1, l2, minimize, dataSetFeatureMapping, dataSetLabelMapping, dataSetFeatureMaskMapping,
                dataSetLabelMaskMapping, trainableParams, 0, 0);
    }

    /** Creates an empty configuration. All fields are null/0 except {@code minimize}, which is {@code true}. */
    public TrainingConfig() {
    }

    /** All-fields constructor. Arguments are stored as given, without validation or copying. */
    public TrainingConfig(IUpdater updater, double l1, double l2, boolean minimize,
                          List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping,
                          List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping,
                          List<String> trainableParams, int iterationCount, int epochCount) {
        this.updater = updater;
        this.l1 = l1;
        this.l2 = l2;
        this.minimize = minimize;
        this.dataSetFeatureMapping = dataSetFeatureMapping;
        this.dataSetLabelMapping = dataSetLabelMapping;
        this.dataSetFeatureMaskMapping = dataSetFeatureMaskMapping;
        this.dataSetLabelMaskMapping = dataSetLabelMaskMapping;
        this.trainableParams = trainableParams;
        this.iterationCount = iterationCount;
        this.epochCount = epochCount;
    }

    /** Adds one to the iteration counter. Not synchronized. */
    public void incrementIterationCount() {
        this.iterationCount++;
    }

    /** Adds one to the epoch counter. Not synchronized. */
    public void incrementEpochCount() {
        this.epochCount++;
    }

    /** Returns a new validating {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Returns the DataSet label array position mapped to {@code name}.
     *
     * @return the index in the label mapping, or -1 if {@code name} is not present
     * @throws NullPointerException if no label mapping has been set
     */
    public int labelIdx(String name) {
        return this.dataSetLabelMapping.indexOf(name);
    }

    public IUpdater getUpdater() {
        return this.updater;
    }

    public double getL1() {
        return this.l1;
    }

    public double getL2() {
        return this.l2;
    }

    public boolean isMinimize() {
        return this.minimize;
    }

    public List<String> getDataSetFeatureMapping() {
        return this.dataSetFeatureMapping;
    }

    public List<String> getDataSetLabelMapping() {
        return this.dataSetLabelMapping;
    }

    public List<String> getDataSetFeatureMaskMapping() {
        return this.dataSetFeatureMaskMapping;
    }

    public List<String> getDataSetLabelMaskMapping() {
        return this.dataSetLabelMaskMapping;
    }

    public List<String> getTrainableParams() {
        return this.trainableParams;
    }

    public int getIterationCount() {
        return this.iterationCount;
    }

    public int getEpochCount() {
        return this.epochCount;
    }

    public void setUpdater(IUpdater updater) {
        this.updater = updater;
    }

    public void setL1(double l1) {
        this.l1 = l1;
    }

    public void setL2(double l2) {
        this.l2 = l2;
    }

    public void setMinimize(boolean minimize) {
        this.minimize = minimize;
    }

    public void setDataSetFeatureMapping(List<String> mapping) {
        this.dataSetFeatureMapping = mapping;
    }

    public void setDataSetLabelMapping(List<String> mapping) {
        this.dataSetLabelMapping = mapping;
    }

    public void setDataSetFeatureMaskMapping(List<String> mapping) {
        this.dataSetFeatureMaskMapping = mapping;
    }

    public void setDataSetLabelMaskMapping(List<String> mapping) {
        this.dataSetLabelMaskMapping = mapping;
    }

    public void setTrainableParams(List<String> names) {
        this.trainableParams = names;
    }

    public void setIterationCount(int iterationCount) {
        this.iterationCount = iterationCount;
    }

    public void setEpochCount(int epochCount) {
        this.epochCount = epochCount;
    }

    /**
     * Field-by-field equality, including the mutable iteration and epoch counters. Doubles are
     * compared with {@link Double#compare}, so NaN equals NaN and 0.0 differs from -0.0.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof TrainingConfig)) {
            return false;
        }
        TrainingConfig other = (TrainingConfig) obj;
        // canEqual keeps equals symmetric when subclasses add state.
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getUpdater(), other.getUpdater())
                && Double.compare(getL1(), other.getL1()) == 0
                && Double.compare(getL2(), other.getL2()) == 0
                && isMinimize() == other.isMinimize()
                && Objects.equals(getDataSetFeatureMapping(), other.getDataSetFeatureMapping())
                && Objects.equals(getDataSetLabelMapping(), other.getDataSetLabelMapping())
                && Objects.equals(getDataSetFeatureMaskMapping(), other.getDataSetFeatureMaskMapping())
                && Objects.equals(getDataSetLabelMaskMapping(), other.getDataSetLabelMaskMapping())
                && Objects.equals(getTrainableParams(), other.getTrainableParams())
                && getIterationCount() == other.getIterationCount()
                && getEpochCount() == other.getEpochCount();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof TrainingConfig;
    }

    /**
     * Hash consistent with {@link #equals(Object)}. It uses the same prime-59 scheme as the
     * previously generated implementation, so hash values are unchanged.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrDefault(getUpdater());
        long l1Bits = Double.doubleToLongBits(getL1());
        result = result * prime + (int) ((l1Bits >>> 32) ^ l1Bits);
        long l2Bits = Double.doubleToLongBits(getL2());
        result = result * prime + (int) ((l2Bits >>> 32) ^ l2Bits);
        result = result * prime + (isMinimize() ? 79 : 97);
        result = result * prime + hashOrDefault(getDataSetFeatureMapping());
        result = result * prime + hashOrDefault(getDataSetLabelMapping());
        result = result * prime + hashOrDefault(getDataSetFeatureMaskMapping());
        result = result * prime + hashOrDefault(getDataSetLabelMaskMapping());
        result = result * prime + hashOrDefault(getTrainableParams());
        result = result * prime + getIterationCount();
        result = result * prime + getEpochCount();
        return result;
    }

    /** Returns {@code o.hashCode()}, or 43 for {@code null} (the constant used by the original hash). */
    private static int hashOrDefault(Object o) {
        return o == null ? 43 : o.hashCode();
    }

    @Override
    public String toString() {
        return "TrainingConfig(updater=" + getUpdater() + ", l1=" + getL1() + ", l2=" + getL2() + ", minimize=" + isMinimize() + ", dataSetFeatureMapping=" + getDataSetFeatureMapping() + ", dataSetLabelMapping=" + getDataSetLabelMapping() + ", dataSetFeatureMaskMapping=" + getDataSetFeatureMaskMapping() + ", dataSetLabelMaskMapping=" + getDataSetLabelMaskMapping() + ", trainableParams=" + getTrainableParams() + ", iterationCount=" + getIterationCount() + ", epochCount=" + getEpochCount() + ")";
    }
}
