/*
 * Decompiled with CFR 0.152.
 */
package org.apache.submarine.server.model;

import java.nio.charset.StandardCharsets;
import javax.ws.rs.core.Response;
import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
import org.apache.submarine.server.SubmitterManager;
import org.apache.submarine.server.api.Submitter;
import org.apache.submarine.server.api.model.ServeResponse;
import org.apache.submarine.server.api.model.ServeSpec;
import org.apache.submarine.server.api.proto.TritonModelConfig;
import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
import org.apache.submarine.server.model.database.service.ModelVersionService;
import org.apache.submarine.server.s3.Client;
import org.json.JSONArray;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Creates and deletes model serves.
 *
 * <p>For each request it loads the registered model version, fills in the serve
 * type, id and storage URI, and delegates deployment to the configured {@link Submitter}.
 * It then records the resulting stage on the model version.
 */
public class ModelManager {
    private static final Logger LOG = LoggerFactory.getLogger(ModelManager.class);

    /** Bucket that holds the model registry artifacts. */
    private static final String REGISTRY_BUCKET = "submarine";
    private static final String MODEL_TYPE_PYTORCH = "pytorch";
    private static final String MODEL_TYPE_TENSORFLOW = "tensorflow";
    private static final String STAGE_PRODUCTION = "Production";
    private static final String STAGE_NONE = "None";

    /** Lazily created singleton; volatile so the double-checked initialization publishes it safely. */
    private static volatile ModelManager manager;

    private final Submitter submitter;
    private final ModelVersionService modelVersionService;

    private ModelManager(Submitter submitter, ModelVersionService modelVersionService) {
        this.submitter = submitter;
        this.modelVersionService = modelVersionService;
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * <p>The field is re-checked inside the lock so that concurrent first callers cannot
     * each construct (and overwrite) a separate manager.
     *
     * @return the singleton {@code ModelManager}
     */
    public static ModelManager getInstance() {
        ModelManager instance = manager;
        if (instance == null) {
            synchronized (ModelManager.class) {
                instance = manager;
                if (instance == null) {
                    instance = new ModelManager(SubmitterManager.loadSubmitter(), new ModelVersionService());
                    manager = instance;
                }
            }
        }
        return instance;
    }

    /**
     * Deploys the model version named by {@code spec} and sets its stage to "Production".
     *
     * @param spec serve request; its model type, id and URI are filled in from the registry
     * @return the inference endpoint of the new serve
     * @throws SubmarineRuntimeException if the spec is invalid, the model version is not found,
     *     or its model type is unsupported
     */
    public ServeResponse createServe(ServeSpec spec) throws SubmarineRuntimeException {
        ModelVersionEntity modelVersion = resolveModelVersion(spec);
        LOG.info("Create {} model serve.", spec.getModelType());
        if (MODEL_TYPE_PYTORCH.equals(spec.getModelType())) {
            // Write the Triton config.pbtxt derived from the model's description.json.
            transferDescription(spec);
        }
        submitter.createServe(spec);
        modelVersion.setCurrentStage(STAGE_PRODUCTION);
        modelVersionService.update(modelVersion);
        return getServeResponse(spec);
    }

    /**
     * Removes the serve of the model version named by {@code spec} and resets its stage to "None".
     *
     * @param spec serve request identifying the model name and version
     * @throws SubmarineRuntimeException if the spec is invalid, the model version is not found,
     *     or its model type is unsupported
     */
    public void deleteServe(ServeSpec spec) throws SubmarineRuntimeException {
        ModelVersionEntity modelVersion = resolveModelVersion(spec);
        LOG.info("Delete {} model serve", spec.getModelType());
        submitter.deleteServe(spec);
        modelVersion.setCurrentStage(STAGE_NONE);
        modelVersionService.update(modelVersion);
    }

    /**
     * Validates {@code spec}, loads the matching model version, and fills the serve info into the spec.
     *
     * @param spec serve request to validate and complete
     * @return the model version entity referenced by the spec
     * @throws SubmarineRuntimeException if validation fails or the model version is not found
     */
    private ModelVersionEntity resolveModelVersion(ServeSpec spec) throws SubmarineRuntimeException {
        // Validate before dereferencing spec, so a null spec is reported as a bad request rather than an NPE.
        checkServeSpec(spec);
        ModelVersionEntity modelVersion = modelVersionService.select(spec.getModelName(), spec.getModelVersion());
        if (modelVersion == null) {
            // NOTE(review): assumes select() returns null when no row matches; harmless if it throws instead.
            throw new SubmarineRuntimeException(Response.Status.NOT_FOUND.getStatusCode(),
                    String.format("Invalid. Model %s version %d is not found.",
                            spec.getModelName(), spec.getModelVersion()));
        }
        setServeInfo(spec, modelVersion);
        return modelVersion;
    }

    /**
     * Rejects a null spec, a null model name, or a missing/non-positive model version.
     *
     * <p>Invalid input is reported as 400 Bad Request (it was previously reported with status 200).
     *
     * @param spec serve request to validate
     * @throws SubmarineRuntimeException if the spec is invalid
     */
    private void checkServeSpec(ServeSpec spec) throws SubmarineRuntimeException {
        if (spec == null) {
            throw new SubmarineRuntimeException(Response.Status.BAD_REQUEST.getStatusCode(),
                    "Invalid. Serve Spec object is null.");
        }
        if (spec.getModelName() == null) {
            throw new SubmarineRuntimeException(Response.Status.BAD_REQUEST.getStatusCode(),
                    "Invalid. Model name in Serve Spec is null.");
        }
        Integer modelVersion = spec.getModelVersion();
        if (modelVersion == null || modelVersion <= 0) {
            throw new SubmarineRuntimeException(Response.Status.BAD_REQUEST.getStatusCode(),
                    "Invalid. Model version must be positive, but get " + modelVersion);
        }
    }

    /**
     * Copies the model type and id from {@code modelVersion} into {@code spec}.
     * Also sets the registry URI for the model type.
     *
     * @param spec         serve request to complete (already validated)
     * @param modelVersion registry entry for the requested model version
     * @throws SubmarineRuntimeException if the model type is neither pytorch nor tensorflow
     */
    private void setServeInfo(ServeSpec spec, ModelVersionEntity modelVersion) throws SubmarineRuntimeException {
        String modelType = modelVersion.getModelType();
        String modelId = modelVersion.getId();
        spec.setModelType(modelType);
        spec.setModelId(modelId);
        String modelUniquePath = String.format("%s-%d-%s", spec.getModelName(), spec.getModelVersion(), modelId);
        // Constant-first comparison: a null model type falls through to the error below instead of an NPE.
        if (MODEL_TYPE_PYTORCH.equals(modelType)) {
            spec.setModelURI(String.format("s3://%s/registry/%s", REGISTRY_BUCKET, modelUniquePath));
        } else if (MODEL_TYPE_TENSORFLOW.equals(modelType)) {
            spec.setModelURI(String.format("s3://%s/registry/%s/%s",
                    REGISTRY_BUCKET, modelUniquePath, spec.getModelName()));
        } else {
            throw new SubmarineRuntimeException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(),
                    String.format("Unexpected model type: %s", modelType));
        }
    }

    /**
     * Builds a Triton {@code config.pbtxt} for a pytorch model and uploads it to the registry.
     *
     * <p>Reads the model's {@code description.json} from S3. Each entry of its "input" and
     * "output" arrays becomes a FP32 tensor named {@code INPUT__i} / {@code OUTPUT__i},
     * using that entry's "dims".
     *
     * @param spec serve spec with model name, version and id already filled in
     */
    private void transferDescription(ServeSpec spec) {
        Client s3Client = new Client();
        String modelUniquePath = String.format("%s-%d-%s", spec.getModelName(), spec.getModelVersion(), spec.getModelId());
        String descriptionPath = String.format("registry/%s/%s/%d/description.json",
                modelUniquePath, spec.getModelName(), spec.getModelVersion());
        // Decode explicitly as UTF-8 rather than with the platform default charset.
        JSONObject description = new JSONObject(
                new String(s3Client.downloadArtifact(descriptionPath), StandardCharsets.UTF_8));

        TritonModelConfig.ModelConfig.Builder modelConfig = TritonModelConfig.ModelConfig.newBuilder();
        modelConfig.setPlatform("pytorch_libtorch");

        JSONArray inputs = description.getJSONArray("input");
        for (int idx = 0; idx < inputs.length(); idx++) {
            JSONArray dims = inputs.getJSONObject(idx).getJSONArray("dims");
            TritonModelConfig.ModelInput.Builder modelInput = TritonModelConfig.ModelInput.newBuilder();
            modelInput.setName("INPUT__" + idx);
            modelInput.setDataType(TritonModelConfig.DataType.TYPE_FP32);
            // getLong accepts any numeric representation org.json may produce (Integer, Long, ...).
            for (int d = 0; d < dims.length(); d++) {
                modelInput.addDims(dims.getLong(d));
            }
            modelConfig.addInput(modelInput);
        }

        JSONArray outputs = description.getJSONArray("output");
        for (int idx = 0; idx < outputs.length(); idx++) {
            JSONArray dims = outputs.getJSONObject(idx).getJSONArray("dims");
            TritonModelConfig.ModelOutput.Builder modelOutput = TritonModelConfig.ModelOutput.newBuilder();
            modelOutput.setName("OUTPUT__" + idx);
            modelOutput.setDataType(TritonModelConfig.DataType.TYPE_FP32);
            for (int d = 0; d < dims.length(); d++) {
                modelOutput.addDims(dims.getLong(d));
            }
            modelConfig.addOutput(modelOutput);
        }

        s3Client.logArtifact(String.format("registry/%s/%s/config.pbtxt", modelUniquePath, spec.getModelName()),
                modelConfig.toString().getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Builds the endpoint URL returned to the client for a created serve.
     *
     * @param spec serve spec with model type, name and version filled in
     * @return response whose URL is the Triton v2 infer path for pytorch; otherwise the predictions path
     */
    private ServeResponse getServeResponse(ServeSpec spec) {
        ServeResponse serveResponse = new ServeResponse();
        if (MODEL_TYPE_PYTORCH.equals(spec.getModelType())) {
            serveResponse.setUrl(String.format("http://{submarine ip}/%s/%d/v2/models/%s/infer",
                    spec.getModelName(), spec.getModelVersion(), spec.getModelName()));
        } else {
            serveResponse.setUrl(String.format("http://{submarine ip}/%s/%d/api/v1.0/predictions",
                    spec.getModelName(), spec.getModelVersion()));
        }
        return serveResponse;
    }
}

