package org.apache.submarine.server.model;

import java.nio.charset.StandardCharsets;
import javax.ws.rs.core.Response;
import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
import org.apache.submarine.server.SubmitterManager;
import org.apache.submarine.server.api.Submitter;
import org.apache.submarine.server.api.model.ServeResponse;
import org.apache.submarine.server.api.model.ServeSpec;
import org.apache.submarine.server.api.proto.TritonModelConfig;
import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
import org.apache.submarine.server.model.database.service.ModelVersionService;
import org.apache.submarine.server.s3.Client;
import org.apache.submarine.server.s3.S3Constants;
import org.json.JSONArray;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/submarine/server/model/ModelManager.class */
/**
 * Creates and deletes model serves.
 *
 * <p>For each request it looks up the registered model version and derives the model type, id
 * and S3 model URI from it. For PyTorch models it also generates a Triton {@code config.pbtxt}.
 * It then delegates to the {@link Submitter} and records the version's stage
 * ({@code "Production"} / {@code "None"}) via {@link ModelVersionService}.
 */
public class ModelManager {
    private static final Logger LOG = LoggerFactory.getLogger(ModelManager.class);

    private static final String PYTORCH = "pytorch";
    private static final String TENSORFLOW = "tensorflow";
    private static final String STAGE_PRODUCTION = "Production";
    private static final String STAGE_NONE = "None";

    // volatile is required for the double-checked locking in getInstance() to publish safely.
    private static volatile ModelManager manager;

    private final Submitter submitter;
    private final ModelVersionService modelVersionService;

    private ModelManager(Submitter submitter, ModelVersionService modelVersionService) {
        this.submitter = submitter;
        this.modelVersionService = modelVersionService;
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the singleton {@code ModelManager}
     */
    public static ModelManager getInstance() {
        ModelManager instance = manager;
        if (instance == null) {
            synchronized (ModelManager.class) {
                // Re-check under the lock: another thread may have created it while we waited.
                instance = manager;
                if (instance == null) {
                    instance = new ModelManager(SubmitterManager.loadSubmitter(), new ModelVersionService());
                    manager = instance;
                }
            }
        }
        return instance;
    }

    /**
     * Creates a serve for the model version named in {@code serveSpec} and marks that version
     * as {@code "Production"}.
     *
     * @param serveSpec model name and version to serve; model type, id and URI are filled in here
     * @return a response carrying the (templated) inference URL
     * @throws SubmarineRuntimeException if the spec is invalid, the model version does not exist,
     *     or the model type is not supported
     */
    public ServeResponse createServe(ServeSpec serveSpec) throws SubmarineRuntimeException {
        ModelVersionEntity entity = loadModelVersion(serveSpec);
        setServeInfo(serveSpec, entity);
        LOG.info("Create {} model serve.", serveSpec.getModelType());
        if (PYTORCH.equals(serveSpec.getModelType())) {
            transferDescription(serveSpec);
        }
        this.submitter.createServe(serveSpec);
        entity.setCurrentStage(STAGE_PRODUCTION);
        this.modelVersionService.update(entity);
        return getServeResponse(serveSpec);
    }

    /**
     * Deletes the serve for the model version named in {@code serveSpec} and resets that
     * version's stage to {@code "None"}.
     *
     * @param serveSpec model name and version whose serve should be removed
     * @throws SubmarineRuntimeException if the spec is invalid, the model version does not exist,
     *     or the model type is not supported
     */
    public void deleteServe(ServeSpec serveSpec) throws SubmarineRuntimeException {
        ModelVersionEntity entity = loadModelVersion(serveSpec);
        setServeInfo(serveSpec, entity);
        LOG.info("Delete {} model serve", serveSpec.getModelType());
        this.submitter.deleteServe(serveSpec);
        entity.setCurrentStage(STAGE_NONE);
        this.modelVersionService.update(entity);
    }

    /**
     * Validates the spec, then fetches the referenced model version.
     *
     * <p>Validation must precede the lookup so that a null or incomplete spec is reported as
     * such instead of failing inside the service call.
     */
    private ModelVersionEntity loadModelVersion(ServeSpec serveSpec) throws SubmarineRuntimeException {
        checkServeSpec(serveSpec);
        ModelVersionEntity entity =
            this.modelVersionService.select(serveSpec.getModelName(), serveSpec.getModelVersion());
        if (entity == null) {
            throw new SubmarineRuntimeException(Response.Status.NOT_FOUND.getStatusCode(),
                String.format("Model %s version %d not found.",
                    serveSpec.getModelName(), serveSpec.getModelVersion()));
        }
        return entity;
    }

    /**
     * Rejects a null spec, a null model name, or a null/non-positive model version.
     */
    private void checkServeSpec(ServeSpec serveSpec) throws SubmarineRuntimeException {
        // NOTE(review): these validation failures are reported with status 200 (OK);
        // BAD_REQUEST looks more appropriate, but changing it may affect REST clients — confirm.
        if (serveSpec == null) {
            throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(), "Invalid. Serve Spec object is null.");
        }
        if (serveSpec.getModelName() == null) {
            throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(), "Invalid. Model name in Serve Spec is null.");
        }
        Integer modelVersion = serveSpec.getModelVersion();
        if (modelVersion == null || modelVersion.intValue() <= 0) {
            throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(), "Invalid. Model version must be positive, but get " + modelVersion);
        }
    }

    /**
     * Copies the model type and id from the stored version into the spec and sets its S3 model URI.
     *
     * @throws SubmarineRuntimeException (500) if the model type is neither pytorch nor tensorflow
     */
    private void setServeInfo(ServeSpec serveSpec, ModelVersionEntity modelVersionEntity) {
        String modelType = modelVersionEntity.getModelType();
        serveSpec.setModelType(modelType);
        serveSpec.setModelId(modelVersionEntity.getId());
        String registryDir = registryDir(serveSpec);
        // Constant-first equals: a null stored type falls through to the "Unexpected" error.
        if (PYTORCH.equals(modelType)) {
            serveSpec.setModelURI(String.format("s3://%s/registry/%s", S3Constants.BUCKET, registryDir));
        } else if (TENSORFLOW.equals(modelType)) {
            serveSpec.setModelURI(String.format("s3://%s/registry/%s/%s",
                S3Constants.BUCKET, registryDir, serveSpec.getModelName()));
        } else {
            throw new SubmarineRuntimeException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(),
                String.format("Unexpected model type: %s", modelType));
        }
    }

    /** Builds the registry directory name {@code <modelName>-<modelVersion>-<modelId>}. */
    private static String registryDir(ServeSpec serveSpec) {
        return String.format("%s-%d-%s",
            serveSpec.getModelName(), serveSpec.getModelVersion(), serveSpec.getModelId());
    }

    /**
     * Converts the model's {@code description.json} into a Triton {@code config.pbtxt} and
     * uploads it next to the model artifact.
     *
     * <p>Every input/output is declared as {@code TYPE_FP32} and named {@code INPUT__i} /
     * {@code OUTPUT__i} by position; only the {@code dims} of each entry are read from the
     * description.
     */
    private void transferDescription(ServeSpec serveSpec) {
        Client client = new Client();
        String registryDir = registryDir(serveSpec);
        String descriptionKey = String.format("registry/%s/%s/%d/description.json",
            registryDir, serveSpec.getModelName(), serveSpec.getModelVersion());
        JSONObject description =
            new JSONObject(new String(client.downloadArtifact(descriptionKey), StandardCharsets.UTF_8));

        TritonModelConfig.ModelConfig.Builder config = TritonModelConfig.ModelConfig.newBuilder();
        config.setPlatform("pytorch_libtorch");

        JSONArray inputs = description.getJSONArray("input");
        for (int i = 0; i < inputs.length(); i++) {
            TritonModelConfig.ModelInput.Builder input = TritonModelConfig.ModelInput.newBuilder();
            input.setName("INPUT__" + i);
            input.setDataType(TritonModelConfig.DataType.TYPE_FP32);
            JSONArray dims = dimsOf(inputs, i);
            for (int d = 0; d < dims.length(); d++) {
                input.addDims(dims.getInt(d));
            }
            config.addInput(input);
        }

        JSONArray outputs = description.getJSONArray("output");
        for (int i = 0; i < outputs.length(); i++) {
            TritonModelConfig.ModelOutput.Builder output = TritonModelConfig.ModelOutput.newBuilder();
            output.setName("OUTPUT__" + i);
            output.setDataType(TritonModelConfig.DataType.TYPE_FP32);
            JSONArray dims = dimsOf(outputs, i);
            for (int d = 0; d < dims.length(); d++) {
                output.addDims(dims.getInt(d));
            }
            config.addOutput(output);
        }

        String configKey = String.format("registry/%s/%s/config.pbtxt", registryDir, serveSpec.getModelName());
        client.logArtifact(configKey, config.toString().getBytes(StandardCharsets.UTF_8));
    }

    /** Returns the {@code dims} array of the {@code index}-th tensor entry in a description list. */
    private static JSONArray dimsOf(JSONArray tensors, int index) {
        return tensors.getJSONObject(index).getJSONArray("dims");
    }

    /**
     * Builds the response URL. PyTorch serves use the Triton-style {@code v2/models/.../infer}
     * path; everything else uses {@code api/v1.0/predictions}. {@code {submarine ip}} is a
     * literal placeholder left for the caller to substitute.
     */
    private ServeResponse getServeResponse(ServeSpec serveSpec) {
        ServeResponse serveResponse = new ServeResponse();
        if (PYTORCH.equals(serveSpec.getModelType())) {
            serveResponse.setUrl(String.format("http://{submarine ip}/%s/%d/v2/models/%s/infer",
                serveSpec.getModelName(), serveSpec.getModelVersion(), serveSpec.getModelName()));
        } else {
            serveResponse.setUrl(String.format("http://{submarine ip}/%s/%d/api/v1.0/predictions",
                serveSpec.getModelName(), serveSpec.getModelVersion()));
        }
        return serveResponse;
    }
}
