package ai.djl.translate;

import ai.djl.Model;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/* loaded from: input_file:ai/djl/translate/ExpansionTranslatorFactory.class */
public abstract class ExpansionTranslatorFactory<IbaseT, ObaseT> implements TranslatorFactory {

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:ai/djl/translate/ExpansionTranslatorFactory$ExpandedTranslatorOptions.class */
    public final class ExpandedTranslatorOptions implements TranslatorOptions {
        private Translator<IbaseT, ObaseT> translator;

        private ExpandedTranslatorOptions(Translator<IbaseT, ObaseT> translator) {
            this.translator = translator;
        }

        @Override // ai.djl.translate.TranslatorOptions
        public Set<Pair<Type, Type>> getOptions() {
            return ExpansionTranslatorFactory.this.getSupportedTypes();
        }

        @Override // ai.djl.translate.TranslatorOptions
        public <I, O> Translator<I, O> option(Class<I> cls, Class<O> cls2) {
            return ExpansionTranslatorFactory.this.newInstance(cls, cls2, this.translator);
        }
    }

    @Override // ai.djl.translate.TranslatorFactory
    public Set<Pair<Type, Type>> getSupportedTypes() {
        return getExpansions().keySet();
    }

    @Override // ai.djl.translate.TranslatorFactory
    public <I, O> Translator<I, O> newInstance(Class<I> cls, Class<O> cls2, Model model, Map<String, ?> map) {
        return newInstance(cls, cls2, buildBaseTranslator(model, map));
    }

    <I, O> Translator<I, O> newInstance(Class<I> cls, Class<O> cls2, Translator<IbaseT, ObaseT> translator) {
        Function<Translator<IbaseT, ObaseT>, Translator<?, ?>> function = getExpansions().get(new Pair(cls, cls2));
        if (function == null) {
            throw new IllegalArgumentException("Unsupported expansion input/output types.");
        }
        return (Translator) function.apply(translator);
    }

    public ExpansionTranslatorFactory<IbaseT, ObaseT>.ExpandedTranslatorOptions withTranslator(Translator<IbaseT, ObaseT> translator) {
        return new ExpandedTranslatorOptions(translator);
    }

    protected abstract Translator<IbaseT, ObaseT> buildBaseTranslator(Model model, Map<String, ?> map);

    protected abstract Map<Pair<Type, Type>, Function<Translator<IbaseT, ObaseT>, Translator<?, ?>>> getExpansions();
}
