package ai.djl.translate;

import ai.djl.Model;
import ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory;
import ai.djl.ndarray.NDList;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:ai/djl/translate/DefaultTranslatorFactory.class */
/**
 * A {@link TranslatorFactory} that resolves a translator for an (input, output) class pair.
 *
 * <p>{@link #newInstance} consults, in order:
 *
 * <ol>
 *   <li>translators registered through {@link #registerTranslator};
 *   <li>a {@link NoopTranslator} for {@code NDList -> NDList};
 *   <li>a {@link ServingTranslatorFactory};
 *   <li>an {@link ImageClassificationTranslatorFactory}.
 * </ol>
 *
 * <p>{@link #registerTranslator} may be called concurrently: lazy creation of the registry is
 * guarded by a lock and the registry itself is a {@link ConcurrentHashMap}.
 */
public class DefaultTranslatorFactory implements TranslatorFactory {

    /**
     * Translators registered via {@link #registerTranslator}, keyed by (input type, output type).
     * Created lazily on the first registration, so it is {@code null} until then.
     */
    protected Map<Pair<Type, Type>, Translator<?, ?>> translators;

    private final ServingTranslatorFactory servingTranslatorFactory =
            new ServingTranslatorFactory();
    private final ImageClassificationTranslatorFactory imageClassificationTranslatorFactory =
            new ImageClassificationTranslatorFactory();

    /** Guards the lazy creation of {@link #translators}. */
    private final Object registrationLock = new Object();

    /**
     * Registers a translator for the given input/output pair. A translator previously registered
     * for the same pair is replaced.
     *
     * @param cls the input class
     * @param cls2 the output class
     * @param translator the translator that {@link #newInstance} returns for this pair
     * @param <I> the input type
     * @param <O> the output type
     */
    public <I, O> void registerTranslator(Class<I> cls, Class<O> cls2, Translator<I, O> translator) {
        synchronized (registrationLock) {
            // The check-then-create must be atomic: without the lock, two concurrent first
            // registrations could each install a fresh map, and one entry would be lost.
            if (translators == null) {
                translators = new ConcurrentHashMap<>();
            }
            translators.put(new Pair<>(cls, cls2), translator);
        }
    }

    /**
     * Returns the registered (input, output) pairs plus {@code NDList -> NDList}.
     *
     * <p>NOTE(review): pairs handled only by the serving or image-classification delegate
     * factories are not included here, although {@link #isSupported} accepts them. Confirm
     * whether that omission is intentional.
     *
     * @return a new, mutable set of supported type pairs
     */
    @Override // ai.djl.translate.TranslatorFactory
    public Set<Pair<Type, Type>> getSupportedTypes() {
        Set<Pair<Type, Type>> supported = new HashSet<>();
        Map<Pair<Type, Type>, Translator<?, ?>> registry = translators;
        if (registry != null) {
            supported.addAll(registry.keySet());
        }
        supported.add(new Pair<>(NDList.class, NDList.class));
        return supported;
    }

    /**
     * Tells whether any translator source of this factory can handle the given pair.
     *
     * @param cls the input class
     * @param cls2 the output class
     * @return {@code true} if the pair is {@code NDList -> NDList}, is registered, or is accepted
     *     by the serving or image-classification delegate factory
     */
    @Override // ai.djl.translate.TranslatorFactory
    public boolean isSupported(Class<?> cls, Class<?> cls2) {
        if (isNoopPair(cls, cls2)) {
            return true;
        }
        Map<Pair<Type, Type>, Translator<?, ?>> registry = translators;
        if (registry != null && registry.containsKey(new Pair<Type, Type>(cls, cls2))) {
            return true;
        }
        return servingTranslatorFactory.isSupported(cls, cls2)
                || imageClassificationTranslatorFactory.isSupported(cls, cls2);
    }

    /**
     * Returns a translator for the given pair, trying the sources listed in the class doc in order.
     *
     * <p>A registered translator is returned as the same instance on every call. It is not copied.
     *
     * @param cls the input class
     * @param cls2 the output class
     * @param model the model passed to a delegate factory
     * @param map the arguments passed to a delegate factory
     * @return the translator, or {@code null} if no source supports the pair
     * @throws TranslateException if a delegate factory fails to create its translator
     */
    @Override // ai.djl.translate.TranslatorFactory
    public Translator<?, ?> newInstance(
            Class<?> cls, Class<?> cls2, Model model, Map<String, ?> map)
            throws TranslateException {
        Map<Pair<Type, Type>, Translator<?, ?>> registry = translators;
        if (registry != null) {
            Translator<?, ?> registered = registry.get(new Pair<Type, Type>(cls, cls2));
            if (registered != null) {
                return registered;
            }
        }
        if (isNoopPair(cls, cls2)) {
            return new NoopTranslator();
        }
        if (servingTranslatorFactory.isSupported(cls, cls2)) {
            return servingTranslatorFactory.newInstance(cls, cls2, model, map);
        }
        if (imageClassificationTranslatorFactory.isSupported(cls, cls2)) {
            return imageClassificationTranslatorFactory.newInstance(cls, cls2, model, map);
        }
        return null;
    }

    /** Returns whether the pair is exactly {@code NDList -> NDList}, handled by a no-op. */
    private static boolean isNoopPair(Class<?> cls, Class<?> cls2) {
        return cls == NDList.class && cls2 == NDList.class;
    }
}
