package org.nd4j.autodiff.listeners;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.shade.guava.collect.Sets;

/* loaded from: input_file:org/nd4j/autodiff/listeners/ListenerVariables.class */
/**
 * Holds four sets of variable names, one for each of the {@link Operation} values
 * {@code TRAINING}, {@code TRAINING_VALIDATION}, {@code EVALUATION} and {@code INFERENCE}.
 *
 * <p>Instances are created through the public four-argument constructor, through
 * {@link #builder()}, or by combining two instances with {@link #merge(ListenerVariables)}.
 * The accessors return the internal sets directly (no copy is made), so changes made to a
 * returned set are visible through this object.
 */
public class ListenerVariables {

    @NonNull
    private Set<String> trainingVariables;

    @NonNull
    private Set<String> validationVariables;

    @NonNull
    private Set<String> evaluationVariables;

    @NonNull
    private Set<String> inferenceVariables;

    /* loaded from: input_file:org/nd4j/autodiff/listeners/ListenerVariables$Builder.class */
    /**
     * Fluent builder that accumulates variable names per {@link Operation} and produces a
     * {@link ListenerVariables} from them.
     */
    public static class Builder {

        @NonNull
        private Set<String> trainingVariables = new HashSet<>();

        @NonNull
        private Set<String> validationVariables = new HashSet<>();

        @NonNull
        private Set<String> evaluationVariables = new HashSet<>();

        @NonNull
        private Set<String> inferenceVariables = new HashSet<>();

        /**
         * Adds the given variable names to the set associated with {@code operation}.
         *
         * @param operation the operation whose set receives the names
         * @param strArr    the variable names to add
         * @return this builder
         * @throws NullPointerException if {@code operation} or {@code strArr} is null
         */
        public Builder requireVariables(@NonNull Operation operation, @NonNull String... strArr) {
            if (operation == null) {
                throw new NullPointerException("op is marked @NonNull but is null");
            }
            if (strArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            // There is intentionally no default branch: an operation other than the four
            // listed here leaves the builder unchanged.
            switch (operation) {
                case TRAINING:
                    this.trainingVariables.addAll(Arrays.asList(strArr));
                    break;
                case TRAINING_VALIDATION:
                    this.validationVariables.addAll(Arrays.asList(strArr));
                    break;
                case INFERENCE:
                    this.inferenceVariables.addAll(Arrays.asList(strArr));
                    break;
                case EVALUATION:
                    this.evaluationVariables.addAll(Arrays.asList(strArr));
                    break;
            }
            return this;
        }

        /**
         * Adds the names of the given variables (as returned by {@code SDVariable.name()}) to
         * the set associated with {@code operation}.
         *
         * @param operation     the operation whose set receives the names
         * @param sDVariableArr the variables whose names are added
         * @return this builder
         * @throws NullPointerException if {@code operation}, {@code sDVariableArr} or any of
         *                              its elements is null
         */
        public Builder requireVariables(@NonNull Operation operation, @NonNull SDVariable... sDVariableArr) {
            if (operation == null) {
                throw new NullPointerException("op is marked @NonNull but is null");
            }
            if (sDVariableArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            String[] names = new String[sDVariableArr.length];
            for (int i = 0; i < sDVariableArr.length; i++) {
                names[i] = sDVariableArr[i].name();
            }
            return requireVariables(operation, names);
        }

        /**
         * Adds variable names for {@code Operation.TRAINING}.
         *
         * @param strArr the variable names to add
         * @return this builder
         * @throws NullPointerException if {@code strArr} is null
         */
        public Builder trainingVariables(@NonNull String... strArr) {
            if (strArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            return requireVariables(Operation.TRAINING, strArr);
        }

        /**
         * Adds the names of the given variables for {@code Operation.TRAINING}.
         *
         * @param sDVariableArr the variables whose names are added
         * @return this builder
         * @throws NullPointerException if {@code sDVariableArr} is null
         */
        public Builder trainingVariables(@NonNull SDVariable... sDVariableArr) {
            if (sDVariableArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            return requireVariables(Operation.TRAINING, sDVariableArr);
        }

        /**
         * Adds variable names for {@code Operation.TRAINING_VALIDATION}.
         *
         * @param strArr the variable names to add
         * @return this builder
         * @throws NullPointerException if {@code strArr} is null
         */
        public Builder validationVariables(@NonNull String... strArr) {
            if (strArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            return requireVariables(Operation.TRAINING_VALIDATION, strArr);
        }

        /**
         * Adds the names of the given variables for {@code Operation.TRAINING_VALIDATION}.
         *
         * @param sDVariableArr the variables whose names are added
         * @return this builder
         * @throws NullPointerException if {@code sDVariableArr} is null
         */
        public Builder validationVariables(@NonNull SDVariable... sDVariableArr) {
            if (sDVariableArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            return requireVariables(Operation.TRAINING_VALIDATION, sDVariableArr);
        }

        /**
         * Adds variable names for {@code Operation.INFERENCE}.
         *
         * @param strArr the variable names to add
         * @return this builder
         * @throws NullPointerException if {@code strArr} is null
         */
        public Builder inferenceVariables(@NonNull String... strArr) {
            if (strArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            return requireVariables(Operation.INFERENCE, strArr);
        }

        /**
         * Adds the names of the given variables for {@code Operation.INFERENCE}.
         *
         * @param sDVariableArr the variables whose names are added
         * @return this builder
         * @throws NullPointerException if {@code sDVariableArr} is null
         */
        public Builder inferenceVariables(@NonNull SDVariable... sDVariableArr) {
            if (sDVariableArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            return requireVariables(Operation.INFERENCE, sDVariableArr);
        }

        /**
         * Adds variable names for {@code Operation.EVALUATION}.
         *
         * @param strArr the variable names to add
         * @return this builder
         * @throws NullPointerException if {@code strArr} is null
         */
        public Builder evaluationVariables(@NonNull String... strArr) {
            if (strArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            return requireVariables(Operation.EVALUATION, strArr);
        }

        /**
         * Adds the names of the given variables for {@code Operation.EVALUATION}.
         *
         * @param sDVariableArr the variables whose names are added
         * @return this builder
         * @throws NullPointerException if {@code sDVariableArr} is null
         */
        public Builder evaluationVariables(@NonNull SDVariable... sDVariableArr) {
            if (sDVariableArr == null) {
                throw new NullPointerException("variables is marked @NonNull but is null");
            }
            return requireVariables(Operation.EVALUATION, sDVariableArr);
        }

        /**
         * Creates a {@link ListenerVariables} from the names collected so far.
         *
         * <p>Each set is copied, so later changes to this builder do not affect the returned
         * instance and repeated calls yield independent instances.
         *
         * @return a new instance holding copies of this builder's four sets
         */
        public ListenerVariables build() {
            return new ListenerVariables(
                    new HashSet<>(this.trainingVariables),
                    new HashSet<>(this.validationVariables),
                    new HashSet<>(this.evaluationVariables),
                    new HashSet<>(this.inferenceVariables));
        }

        /** @return the builder's current training set (not a copy) */
        @NonNull
        public Set<String> getTrainingVariables() {
            return this.trainingVariables;
        }

        /** @return the builder's current validation set (not a copy) */
        @NonNull
        public Set<String> getValidationVariables() {
            return this.validationVariables;
        }

        /** @return the builder's current evaluation set (not a copy) */
        @NonNull
        public Set<String> getEvaluationVariables() {
            return this.evaluationVariables;
        }

        /** @return the builder's current inference set (not a copy) */
        @NonNull
        public Set<String> getInferenceVariables() {
            return this.inferenceVariables;
        }

        /**
         * Replaces the builder's training set with the given set (stored by reference).
         *
         * @param set the new training set
         * @throws NullPointerException if {@code set} is null
         */
        public void setTrainingVariables(@NonNull Set<String> set) {
            if (set == null) {
                throw new NullPointerException("trainingVariables is marked @NonNull but is null");
            }
            this.trainingVariables = set;
        }

        /**
         * Replaces the builder's validation set with the given set (stored by reference).
         *
         * @param set the new validation set
         * @throws NullPointerException if {@code set} is null
         */
        public void setValidationVariables(@NonNull Set<String> set) {
            if (set == null) {
                throw new NullPointerException("validationVariables is marked @NonNull but is null");
            }
            this.validationVariables = set;
        }

        /**
         * Replaces the builder's evaluation set with the given set (stored by reference).
         *
         * @param set the new evaluation set
         * @throws NullPointerException if {@code set} is null
         */
        public void setEvaluationVariables(@NonNull Set<String> set) {
            if (set == null) {
                throw new NullPointerException("evaluationVariables is marked @NonNull but is null");
            }
            this.evaluationVariables = set;
        }

        /**
         * Replaces the builder's inference set with the given set (stored by reference).
         *
         * @param set the new inference set
         * @throws NullPointerException if {@code set} is null
         */
        public void setInferenceVariables(@NonNull Set<String> set) {
            if (set == null) {
                throw new NullPointerException("inferenceVariables is marked @NonNull but is null");
            }
            this.inferenceVariables = set;
        }
    }

    /**
     * @return a new instance whose four sets are all empty
     */
    public static ListenerVariables empty() {
        return builder().build();
    }

    /**
     * @return a new, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /** @return the training set (same object as {@link #getTrainingVariables()}) */
    public Set<String> trainingVariables() {
        return this.trainingVariables;
    }

    /** @return the validation set (same object as {@link #getValidationVariables()}) */
    public Set<String> validationVariables() {
        return this.validationVariables;
    }

    /** @return the evaluation set (same object as {@link #getEvaluationVariables()}) */
    public Set<String> evaluationVariables() {
        return this.evaluationVariables;
    }

    /** @return the inference set (same object as {@link #getInferenceVariables()}) */
    public Set<String> inferenceVariables() {
        return this.inferenceVariables;
    }

    /**
     * Returns the set associated with the given operation.
     *
     * @param operation the operation to look up
     * @return the matching set (not a copy)
     * @throws IllegalArgumentException if {@code operation} is not one of the four handled values
     * @throws NullPointerException     if {@code operation} is null (raised by the {@code switch})
     */
    public Set<String> requiredVariables(Operation operation) {
        switch (operation) {
            case TRAINING:
                return this.trainingVariables;
            case TRAINING_VALIDATION:
                return this.validationVariables;
            case INFERENCE:
                return this.inferenceVariables;
            case EVALUATION:
                return this.evaluationVariables;
            default:
                throw new IllegalArgumentException("Unknown operation " + operation);
        }
    }

    /**
     * Private no-argument constructor. Initialises every set to an empty {@link HashSet} so
     * that the {@code @NonNull} fields are never observed as null on an instance created
     * through this constructor.
     */
    private ListenerVariables() {
        this.trainingVariables = new HashSet<>();
        this.validationVariables = new HashSet<>();
        this.evaluationVariables = new HashSet<>();
        this.inferenceVariables = new HashSet<>();
    }

    /**
     * Combines this instance with another one. Neither input is modified; each set of the
     * result is a new {@link HashSet} containing the union of the two corresponding sets.
     *
     * @param listenerVariables the instance to merge with this one
     * @return a new instance holding the per-operation unions
     * @throws NullPointerException if {@code listenerVariables} is null
     */
    public ListenerVariables merge(ListenerVariables listenerVariables) {
        if (listenerVariables == null) {
            throw new NullPointerException("listenerVariables to merge must not be null");
        }
        return new ListenerVariables(
                unionOf(this.trainingVariables, listenerVariables.trainingVariables),
                unionOf(this.validationVariables, listenerVariables.validationVariables),
                unionOf(this.evaluationVariables, listenerVariables.evaluationVariables),
                unionOf(this.inferenceVariables, listenerVariables.inferenceVariables));
    }

    /** Returns a new mutable set containing every element of {@code first} and {@code second}. */
    private static Set<String> unionOf(Set<String> first, Set<String> second) {
        return Sets.newHashSet(Sets.union(first, second));
    }

    /**
     * Creates an instance that stores the four given sets by reference (no copies are made).
     *
     * @param set  the training set
     * @param set2 the validation set
     * @param set3 the evaluation set
     * @param set4 the inference set
     * @throws NullPointerException if any argument is null
     */
    public ListenerVariables(@NonNull Set<String> set, @NonNull Set<String> set2, @NonNull Set<String> set3, @NonNull Set<String> set4) {
        if (set == null) {
            throw new NullPointerException("trainingVariables is marked @NonNull but is null");
        }
        if (set2 == null) {
            throw new NullPointerException("validationVariables is marked @NonNull but is null");
        }
        if (set3 == null) {
            throw new NullPointerException("evaluationVariables is marked @NonNull but is null");
        }
        if (set4 == null) {
            throw new NullPointerException("inferenceVariables is marked @NonNull but is null");
        }
        this.trainingVariables = set;
        this.validationVariables = set2;
        this.evaluationVariables = set3;
        this.inferenceVariables = set4;
    }

    /** @return the training set (not a copy) */
    @NonNull
    public Set<String> getTrainingVariables() {
        return this.trainingVariables;
    }

    /** @return the validation set (not a copy) */
    @NonNull
    public Set<String> getValidationVariables() {
        return this.validationVariables;
    }

    /** @return the evaluation set (not a copy) */
    @NonNull
    public Set<String> getEvaluationVariables() {
        return this.evaluationVariables;
    }

    /** @return the inference set (not a copy) */
    @NonNull
    public Set<String> getInferenceVariables() {
        return this.inferenceVariables;
    }
}
