package org.nd4j.linalg.dataset.api.preprocessor.serializer;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.class */
/**
 * Writes {@link Normalizer} instances to files/streams and restores them again.
 *
 * <p>Every serialized normalizer is prefixed with a small header consisting of the magic string
 * {@code "NORMALIZER"}, an integer format version, the name of the {@link NormalizerType} and,
 * only when a custom strategy is used, the fully qualified class name of that strategy. The
 * payload following the header is produced and consumed by the matching
 * {@link NormalizerSerializerStrategy}.
 *
 * <p>Strategies are consulted in the order in which they were registered through
 * {@link #addStrategy(NormalizerSerializerStrategy)}; the first match wins.
 */
public class NormalizerSerializer {
    /** Magic string written at the very start of every serialized normalizer. */
    private static final String HEADER = "NORMALIZER";

    /** The only header format version that is written and accepted on read. */
    private static final int VERSION = 1;

    /** Lazily created shared instance returned by {@link #getDefault()}. */
    private static NormalizerSerializer defaultSerializer;

    /** Registered strategies, in registration order. */
    private final List<NormalizerSerializerStrategy> strategies = new ArrayList<>();

    /**
     * Immutable description of the header that precedes a serialized normalizer: the normalizer
     * type and, for custom strategies only, the strategy class needed to restore the payload.
     */
    public static final class Header {
        private final NormalizerType normalizerType;
        private final Class<? extends NormalizerSerializerStrategy> customStrategyClass;

        /**
         * Builds the header describing data written by the given strategy.
         *
         * @param normalizerSerializerStrategy the strategy that will write the payload
         * @return a header carrying the strategy's supported type; the strategy class is recorded
         *         only when the strategy is a {@link CustomSerializerStrategy}
         */
        public static Header fromStrategy(NormalizerSerializerStrategy normalizerSerializerStrategy) {
            if (normalizerSerializerStrategy instanceof CustomSerializerStrategy) {
                return new Header(normalizerSerializerStrategy.getSupportedType(),
                        normalizerSerializerStrategy.getClass());
            }
            return new Header(normalizerSerializerStrategy.getSupportedType(), null);
        }

        /**
         * @param normalizerType the type of the serialized normalizer
         * @param cls            the custom strategy class, or {@code null} when not applicable
         */
        public Header(NormalizerType normalizerType, Class<? extends NormalizerSerializerStrategy> cls) {
            this.normalizerType = normalizerType;
            this.customStrategyClass = cls;
        }

        /** @return the type of the serialized normalizer */
        public NormalizerType getNormalizerType() {
            return this.normalizerType;
        }

        /** @return the custom strategy class, or {@code null} if none was recorded */
        public Class<? extends NormalizerSerializerStrategy> getCustomStrategyClass() {
            return this.customStrategyClass;
        }

        /** Two headers are equal when both their type and their custom strategy class are equal. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Header)) {
                return false;
            }
            Header other = (Header) obj;
            return Objects.equals(this.normalizerType, other.normalizerType)
                    && Objects.equals(this.customStrategyClass, other.customStrategyClass);
        }

        @Override
        public int hashCode() {
            // Same formula (prime 59, 43 for null) as before, so hash values are unchanged.
            int result = 59 + (this.normalizerType == null ? 43 : this.normalizerType.hashCode());
            return (result * 59) + (this.customStrategyClass == null ? 43 : this.customStrategyClass.hashCode());
        }

        @Override
        public String toString() {
            return "NormalizerSerializer.Header(normalizerType=" + this.normalizerType
                    + ", customStrategyClass=" + this.customStrategyClass + ")";
        }
    }

    /**
     * Serializes a normalizer to the given file through a buffered stream that is always closed.
     *
     * @param normalizer the normalizer to persist
     * @param file       the destination file
     * @throws IOException          if opening, writing or closing the file fails
     * @throws NullPointerException if either argument is {@code null}
     */
    public void write(@NonNull Normalizer normalizer, @NonNull File file) throws IOException {
        Objects.requireNonNull(normalizer, "normalizer");
        Objects.requireNonNull(file, "file");
        // try-with-resources keeps the primary exception and records a close() failure as suppressed.
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(file))) {
            write(normalizer, out);
        }
    }

    /**
     * Serializes a normalizer to the file at the given path through a buffered stream that is
     * always closed.
     *
     * @param normalizer the normalizer to persist
     * @param str        the path of the destination file
     * @throws IOException          if opening, writing or closing the file fails
     * @throws NullPointerException if either argument is {@code null}
     */
    public void write(@NonNull Normalizer normalizer, @NonNull String str) throws IOException {
        Objects.requireNonNull(normalizer, "normalizer");
        Objects.requireNonNull(str, "path");
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(str))) {
            write(normalizer, out);
        }
    }

    /**
     * Serializes a normalizer to the given stream: first the header, then the payload produced by
     * the first registered strategy that supports the normalizer. The stream is not closed.
     *
     * @param normalizer   the normalizer to persist
     * @param outputStream the stream to write to; remains open and owned by the caller
     * @throws IOException          if writing fails
     * @throws NullPointerException if either argument is {@code null}
     * @throws RuntimeException     if no registered strategy supports the normalizer
     */
    public void write(@NonNull Normalizer normalizer, @NonNull OutputStream outputStream) throws IOException {
        Objects.requireNonNull(normalizer, "normalizer");
        Objects.requireNonNull(outputStream, "stream");
        NormalizerSerializerStrategy strategy = getStrategy(normalizer);
        writeHeader(outputStream, Header.fromStrategy(strategy));
        strategy.write(normalizer, outputStream);
    }

    /**
     * Restores a normalizer from the file at the given path.
     *
     * @param str the path of the file to read
     * @param <T> the normalizer type expected by the caller (not verified here)
     * @return the restored normalizer
     * @throws Exception            if reading, header parsing or strategy lookup/instantiation fails
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public <T extends Normalizer> T restore(@NonNull String str) throws Exception {
        Objects.requireNonNull(str, "path");
        try (InputStream in = new BufferedInputStream(new FileInputStream(str))) {
            return restore(in);
        }
    }

    /**
     * Restores a normalizer from the given file.
     *
     * @param file the file to read
     * @param <T>  the normalizer type expected by the caller (not verified here)
     * @return the restored normalizer
     * @throws Exception            if reading, header parsing or strategy lookup/instantiation fails
     * @throws NullPointerException if {@code file} is {@code null}
     */
    public <T extends Normalizer> T restore(@NonNull File file) throws Exception {
        Objects.requireNonNull(file, "file");
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            return restore(in);
        }
    }

    /**
     * Restores a normalizer from the given stream by parsing the header and delegating the
     * payload to the strategy the header selects. The stream is not closed.
     *
     * @param inputStream the stream to read from; remains open and owned by the caller
     * @param <T>         the normalizer type expected by the caller (not verified here)
     * @return the restored normalizer
     * @throws Exception            if reading, header parsing or strategy lookup/instantiation fails
     * @throws NullPointerException if {@code inputStream} is {@code null}
     */
    @SuppressWarnings("unchecked") // the caller chooses T; the strategy's result cannot be checked here
    public <T extends Normalizer> T restore(@NonNull InputStream inputStream) throws Exception {
        Objects.requireNonNull(inputStream, "stream");
        Header header = parseHeader(inputStream);
        return (T) getStrategy(header).restore(inputStream);
    }

    /**
     * Returns the shared serializer pre-populated with the built-in strategies, creating it on
     * first use.
     *
     * <p>NOTE: the lazy initialization is not synchronized.
     *
     * @return the shared default serializer
     */
    public static NormalizerSerializer getDefault() {
        if (defaultSerializer == null) {
            defaultSerializer = new NormalizerSerializer()
                    .addStrategy(new StandardizeSerializerStrategy())
                    .addStrategy(new MinMaxSerializerStrategy())
                    .addStrategy(new MultiStandardizeSerializerStrategy())
                    .addStrategy(new MultiMinMaxSerializerStrategy())
                    .addStrategy(new MultiHybridSerializerStrategy());
        }
        return defaultSerializer;
    }

    /**
     * Registers an additional strategy; it is consulted after all previously added ones.
     *
     * @param normalizerSerializerStrategy the strategy to register
     * @return this serializer, to allow call chaining
     * @throws NullPointerException if the strategy is {@code null}
     */
    public NormalizerSerializer addStrategy(@NonNull NormalizerSerializerStrategy normalizerSerializerStrategy) {
        Objects.requireNonNull(normalizerSerializerStrategy, "strategy");
        this.strategies.add(normalizerSerializerStrategy);
        return this;
    }

    /**
     * Finds the first registered strategy able to serialize the given normalizer.
     *
     * @throws RuntimeException if no registered strategy matches
     */
    private NormalizerSerializerStrategy getStrategy(Normalizer normalizer) {
        for (NormalizerSerializerStrategy candidate : this.strategies) {
            if (strategySupportsNormalizer(candidate, normalizer.getType(), normalizer.getClass())) {
                return candidate;
            }
        }
        throw new RuntimeException(String.format("No serializer strategy found for normalizer of class %s. If this is a custom normalizer, you probably forgot to register a corresponding custom serializer strategy with this serializer.", normalizer.getClass()));
    }

    /**
     * Resolves the strategy needed to read the payload described by the header. For the CUSTOM
     * type a fresh instance of the class recorded in the header is created through its no-arg
     * constructor; otherwise the first registered strategy of the matching type is returned.
     *
     * @throws Exception        if the custom strategy cannot be instantiated; an exception thrown
     *                          by its constructor arrives wrapped in an
     *                          {@link java.lang.reflect.InvocationTargetException}
     * @throws RuntimeException if no registered strategy matches the header
     */
    private NormalizerSerializerStrategy getStrategy(Header header) throws Exception {
        if (header.getNormalizerType().equals(NormalizerType.CUSTOM)) {
            // Class.newInstance() is deprecated and leaks undeclared checked exceptions.
            return header.getCustomStrategyClass().getDeclaredConstructor().newInstance();
        }
        for (NormalizerSerializerStrategy candidate : this.strategies) {
            if (strategySupportsNormalizer(candidate, header.getNormalizerType(), null)) {
                return candidate;
            }
        }
        throw new RuntimeException("No serializer strategy found for given header " + header);
    }

    /**
     * Tells whether a strategy handles the given normalizer type and, for the CUSTOM type, the
     * given normalizer class as well.
     *
     * @throws IllegalArgumentException if a strategy reports the CUSTOM type without being a
     *                                  {@link CustomSerializerStrategy}
     */
    private boolean strategySupportsNormalizer(NormalizerSerializerStrategy normalizerSerializerStrategy, NormalizerType normalizerType, Class<? extends Normalizer> cls) {
        NormalizerType supportedType = normalizerSerializerStrategy.getSupportedType();
        if (!supportedType.equals(normalizerType)) {
            return false;
        }
        if (!supportedType.equals(NormalizerType.CUSTOM)) {
            // Built-in types are identified by the type alone.
            return true;
        }
        if (!(normalizerSerializerStrategy instanceof CustomSerializerStrategy)) {
            throw new IllegalArgumentException("Strategies supporting CUSTOM opType must be instance of CustomSerializerStrategy, got " + normalizerSerializerStrategy.getClass());
        }
        return ((CustomSerializerStrategy) normalizerSerializerStrategy).getSupportedClass().equals(cls);
    }

    /**
     * Reads and validates the header at the current position of the stream.
     *
     * @throws IOException              if reading from the stream fails
     * @throws ClassNotFoundException   if the custom strategy class named in the header cannot be loaded
     * @throws ClassCastException       if that class is not a {@link NormalizerSerializerStrategy}
     * @throws IllegalArgumentException if the magic string or the version does not match
     */
    private Header parseHeader(InputStream inputStream) throws IOException, ClassNotFoundException {
        // Deliberately not closed: the wrapper only decodes, the stream belongs to the caller
        // and the strategy continues reading the payload from it.
        DataInputStream dataInputStream = new DataInputStream(inputStream);
        if (!dataInputStream.readUTF().equals(HEADER)) {
            throw new IllegalArgumentException("Could not restore normalizer: invalid header. If this normalizer was saved with a opType-specific strategy like StandardizeSerializerStrategy, use that class to restore it as well.");
        }
        int version = dataInputStream.readInt();
        if (version != VERSION) {
            throw new IllegalArgumentException("Could not restore normalizer: invalid version (" + version + ")");
        }
        NormalizerType type = NormalizerType.valueOf(dataInputStream.readUTF());
        if (!type.equals(NormalizerType.CUSTOM)) {
            return new Header(type, null);
        }
        String strategyClassName = dataInputStream.readUTF();
        // SECURITY: the class name comes from the serialized data. Load it without running its
        // static initializers and verify the type before anything is instantiated, so a crafted
        // header cannot make us initialize or construct an unrelated class. Restoring data from
        // untrusted sources still executes the named strategy's code and should be avoided.
        Class<? extends NormalizerSerializerStrategy> strategyClass =
                Class.forName(strategyClassName, false, NormalizerSerializer.class.getClassLoader())
                        .asSubclass(NormalizerSerializerStrategy.class);
        return new Header(type, strategyClass);
    }

    /**
     * Writes the header (magic string, version, type name and optional custom strategy class
     * name) to the stream.
     *
     * @throws IOException if writing fails
     */
    private void writeHeader(OutputStream outputStream, Header header) throws IOException {
        // Deliberately not closed: the stream belongs to the caller and receives the payload next.
        DataOutputStream dataOutputStream = new DataOutputStream(outputStream);
        dataOutputStream.writeUTF(HEADER);
        dataOutputStream.writeInt(VERSION);
        dataOutputStream.writeUTF(header.getNormalizerType().toString());
        if (header.getCustomStrategyClass() != null) {
            dataOutputStream.writeUTF(header.getCustomStrategyClass().getName());
        }
    }
}
