/*
 * Decompiled with CFR 0.152.
 */
package org.openmetadata.service.jdbi3;

import java.util.ArrayList;
import java.util.List;
import org.jdbi.v3.sqlobject.transaction.Transaction;
import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.schema.EntityInterface;
import org.openmetadata.schema.api.feed.ResolveTask;
import org.openmetadata.schema.entity.data.MlModel;
import org.openmetadata.schema.entity.services.MlModelService;
import org.openmetadata.schema.type.EntityReference;
import org.openmetadata.schema.type.Include;
import org.openmetadata.schema.type.MlFeature;
import org.openmetadata.schema.type.MlFeatureSource;
import org.openmetadata.schema.type.Relationship;
import org.openmetadata.schema.type.TagLabel;
import org.openmetadata.schema.type.TaskType;
import org.openmetadata.service.Entity;
import org.openmetadata.service.exception.CatalogExceptionMessage;
import org.openmetadata.service.jdbi3.EntityRepository;
import org.openmetadata.service.jdbi3.FeedRepository;
import org.openmetadata.service.resources.feeds.MessageParser;
import org.openmetadata.service.util.EntityUtil;
import org.openmetadata.service.util.FullyQualifiedName;
import org.openmetadata.service.util.JsonUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MlModelRepository
extends EntityRepository<MlModel> {
    private static final Logger LOG = LoggerFactory.getLogger(MlModelRepository.class);
    private static final String MODEL_UPDATE_FIELDS = "dashboard";
    private static final String MODEL_PATCH_FIELDS = "dashboard";

    public MlModelRepository() {
        super("v1/mlmodels/", "mlmodel", MlModel.class, Entity.getCollectionDAO().mlModelDAO(), "dashboard", "dashboard");
        this.supportsSearch = true;
    }

    public static MlFeature findMlFeature(List<MlFeature> features, String featureName) {
        return features.stream().filter(c -> c.getName().equals(featureName)).findFirst().orElseThrow(() -> new IllegalArgumentException(CatalogExceptionMessage.invalidFieldName("mlFeature", featureName)));
    }

    @Override
    public void setFullyQualifiedName(MlModel mlModel) {
        mlModel.setFullyQualifiedName(FullyQualifiedName.add(mlModel.getService().getFullyQualifiedName(), mlModel.getName()));
        if (!CommonUtil.nullOrEmpty((List)mlModel.getMlFeatures())) {
            this.setMlFeatureFQN(mlModel.getFullyQualifiedName(), mlModel.getMlFeatures());
        }
    }

    @Override
    public MlModel setFields(MlModel mlModel, EntityUtil.Fields fields) {
        mlModel.setService(this.getContainer(mlModel.getId()));
        mlModel.setDashboard(fields.contains("dashboard") ? this.getDashboard(mlModel) : mlModel.getDashboard());
        if (mlModel.getUsageSummary() == null) {
            mlModel.withUsageSummary(fields.contains("usageSummary") ? EntityUtil.getLatestUsage(this.daoCollection.usageDAO(), mlModel.getId()) : mlModel.getUsageSummary());
        }
        return mlModel;
    }

    @Override
    public MlModel clearFields(MlModel mlModel, EntityUtil.Fields fields) {
        mlModel.setDashboard(fields.contains("dashboard") ? mlModel.getDashboard() : null);
        return mlModel.withUsageSummary(fields.contains("usageSummary") ? mlModel.getUsageSummary() : null);
    }

    @Override
    public void restorePatchAttributes(MlModel original, MlModel updated) {
        updated.withFullyQualifiedName(original.getFullyQualifiedName()).withService(original.getService()).withName(original.getName()).withId(original.getId());
    }

    private void setMlFeatureSourcesFQN(List<MlFeatureSource> mlSources) {
        mlSources.forEach(s -> {
            if (s.getDataSource() != null) {
                s.setFullyQualifiedName(FullyQualifiedName.add(s.getDataSource().getFullyQualifiedName(), s.getName()));
            } else {
                s.setFullyQualifiedName(s.getName());
            }
        });
    }

    private void setMlFeatureFQN(String parentFQN, List<MlFeature> mlFeatures) {
        mlFeatures.forEach(f -> {
            String featureFqn = FullyQualifiedName.add(parentFQN, f.getName());
            f.setFullyQualifiedName(featureFqn);
            if (f.getFeatureSources() != null) {
                this.setMlFeatureSourcesFQN(f.getFeatureSources());
            }
        });
    }

    private void validateReferences(List<MlFeature> mlFeatures) {
        for (MlFeature feature : mlFeatures) {
            if (CommonUtil.nullOrEmpty((List)feature.getFeatureSources())) continue;
            for (MlFeatureSource source : feature.getFeatureSources()) {
                this.validateMlDataSource(source);
            }
        }
    }

    private void validateMlDataSource(MlFeatureSource source) {
        if (source.getDataSource() != null) {
            Entity.getEntityReferenceById(source.getDataSource().getType(), source.getDataSource().getId(), Include.NON_DELETED);
        }
    }

    @Override
    public void prepare(MlModel mlModel, boolean update) {
        this.populateService(mlModel);
        if (!CommonUtil.nullOrEmpty((List)mlModel.getMlFeatures())) {
            this.validateReferences(mlModel.getMlFeatures());
            mlModel.getMlFeatures().forEach(feature -> this.checkMutuallyExclusive(feature.getTags()));
        }
        if (mlModel.getDashboard() != null) {
            mlModel.setDashboard(Entity.getEntityReference(mlModel.getDashboard(), Include.NON_DELETED));
        }
    }

    @Override
    public void storeEntity(MlModel mlModel, boolean update) {
        EntityReference dashboard = mlModel.getDashboard();
        EntityReference service = mlModel.getService();
        mlModel.withService(null).withDashboard(null);
        this.store(mlModel, update);
        mlModel.withService(service).withDashboard(dashboard);
    }

    @Override
    public void storeRelationships(MlModel mlModel) {
        EntityReference service = mlModel.getService();
        this.addRelationship(service.getId(), mlModel.getId(), service.getType(), "mlmodel", Relationship.CONTAINS);
        this.setDashboard(mlModel, mlModel.getDashboard());
        if (mlModel.getDashboard() != null) {
            this.addRelationship(mlModel.getId(), mlModel.getDashboard().getId(), "mlmodel", "dashboard", Relationship.USES);
        }
        this.setMlFeatureSourcesLineage(mlModel);
    }

    @Override
    public MlModel setInheritedFields(MlModel mlModel, EntityUtil.Fields fields) {
        MlModelService service = (MlModelService)Entity.getEntity("mlmodelService", mlModel.getService().getId(), "domain", Include.ALL);
        return this.inheritDomain(mlModel, fields, (EntityInterface)service);
    }

    private void setMlFeatureSourcesLineage(MlModel mlModel) {
        if (mlModel.getMlFeatures() != null) {
            mlModel.getMlFeatures().forEach(mlFeature -> {
                if (mlFeature.getFeatureSources() != null) {
                    mlFeature.getFeatureSources().forEach(mlFeatureSource -> {
                        EntityReference targetEntity = mlFeatureSource.getDataSource();
                        if (targetEntity != null) {
                            this.addRelationship(targetEntity.getId(), mlModel.getId(), targetEntity.getType(), "mlmodel", Relationship.UPSTREAM);
                        }
                    });
                }
            });
        }
    }

    @Override
    public EntityRepository.EntityUpdater getUpdater(MlModel original, MlModel updated, EntityRepository.Operation operation) {
        return new MlModelUpdater(original, updated, operation);
    }

    @Override
    public EntityInterface getParentEntity(MlModel entity, String fields) {
        return (EntityInterface)Entity.getEntity(entity.getService(), fields, Include.NON_DELETED);
    }

    @Override
    public List<TagLabel> getAllTags(EntityInterface entity) {
        ArrayList<TagLabel> allTags = new ArrayList<TagLabel>();
        MlModel mlModel = (MlModel)entity;
        EntityUtil.mergeTags(allTags, mlModel.getTags());
        for (MlFeature feature : CommonUtil.listOrEmpty((List)mlModel.getMlFeatures())) {
            EntityUtil.mergeTags(allTags, feature.getTags());
            for (MlFeatureSource source : CommonUtil.listOrEmpty((List)feature.getFeatureSources())) {
                EntityUtil.mergeTags(allTags, source.getTags());
            }
        }
        return allTags;
    }

    @Override
    public FeedRepository.TaskWorkflow getTaskWorkflow(FeedRepository.ThreadContext threadContext) {
        this.validateTaskThread(threadContext);
        MessageParser.EntityLink entityLink = threadContext.getAbout();
        if (entityLink.getFieldName().equals("mlFeatures")) {
            TaskType taskType = threadContext.getThread().getTask().getType();
            if (EntityUtil.isDescriptionTask(taskType)) {
                return new MlFeatureDescriptionTaskWorkflow(threadContext);
            }
            if (EntityUtil.isTagTask(taskType)) {
                return new MlFeatureTagTaskWorkflow(threadContext);
            }
            throw new IllegalArgumentException(String.format("Invalid task type %s", taskType));
        }
        return super.getTaskWorkflow(threadContext);
    }

    private void populateService(MlModel mlModel) {
        MlModelService service = (MlModelService)Entity.getEntity(mlModel.getService(), "", Include.NON_DELETED);
        mlModel.setService(service.getEntityReference());
        mlModel.setServiceType(service.getServiceType());
    }

    private EntityReference getDashboard(MlModel mlModel) {
        return mlModel == null ? null : this.getToEntityRef(mlModel.getId(), Relationship.USES, "dashboard", false);
    }

    public void setDashboard(MlModel mlModel, EntityReference dashboard) {
        if (dashboard != null) {
            this.addRelationship(mlModel.getId(), mlModel.getDashboard().getId(), "mlmodel", "dashboard", Relationship.USES);
        }
    }

    public class MlModelUpdater
    extends EntityRepository.EntityUpdater {
        public MlModelUpdater(MlModel original, MlModel updated, EntityRepository.Operation operation) {
            super((EntityRepository)MlModelRepository.this, (EntityInterface)original, (EntityInterface)updated, operation);
        }

        @Override
        @Transaction
        public void entitySpecificUpdate() {
            this.updateAlgorithm((MlModel)this.original, (MlModel)this.updated);
            this.updateDashboard((MlModel)this.original, (MlModel)this.updated);
            this.updateMlFeatures((MlModel)this.original, (MlModel)this.updated);
            this.updateMlHyperParameters((MlModel)this.original, (MlModel)this.updated);
            this.updateMlStore((MlModel)this.original, (MlModel)this.updated);
            this.updateServer((MlModel)this.original, (MlModel)this.updated);
            this.updateTarget((MlModel)this.original, (MlModel)this.updated);
            this.recordChange("sourceUrl", ((MlModel)this.original).getSourceUrl(), ((MlModel)this.updated).getSourceUrl());
        }

        private void updateAlgorithm(MlModel origModel, MlModel updatedModel) {
            if (((MlModel)this.updated).getAlgorithm() != null && this.recordChange("algorithm", origModel.getAlgorithm(), updatedModel.getAlgorithm())) {
                this.majorVersionChange = true;
            }
        }

        private void updateMlFeatures(MlModel origModel, MlModel updatedModel) {
            ArrayList addedList = new ArrayList();
            ArrayList deletedList = new ArrayList();
            this.recordListChange("mlFeatures", origModel.getMlFeatures(), updatedModel.getMlFeatures(), addedList, deletedList, EntityUtil.mlFeatureMatch);
        }

        private void updateMlHyperParameters(MlModel origModel, MlModel updatedModel) {
            ArrayList addedList = new ArrayList();
            ArrayList deletedList = new ArrayList();
            this.recordListChange("mlHyperParameters", origModel.getMlHyperParameters(), updatedModel.getMlHyperParameters(), addedList, deletedList, EntityUtil.mlHyperParameterMatch);
        }

        private void updateMlStore(MlModel origModel, MlModel updatedModel) {
            this.recordChange("mlStore", origModel.getMlStore(), updatedModel.getMlStore(), true);
        }

        private void updateServer(MlModel origModel, MlModel updatedModel) {
            if (this.recordChange("server", origModel.getServer(), updatedModel.getServer())) {
                this.majorVersionChange = true;
            }
        }

        private void updateTarget(MlModel origModel, MlModel updatedModel) {
            if (this.recordChange("target", origModel.getTarget(), updatedModel.getTarget())) {
                this.majorVersionChange = true;
            }
        }

        private void updateDashboard(MlModel origModel, MlModel updatedModel) {
            EntityReference updatedDashboard;
            EntityReference origDashboard = origModel.getDashboard();
            if (this.recordChange("dashboard", origDashboard, updatedDashboard = updatedModel.getDashboard(), true, EntityUtil.entityReferenceMatch)) {
                if (origModel.getDashboard() != null) {
                    MlModelRepository.this.deleteTo(updatedModel.getId(), "mlmodel", Relationship.USES, "dashboard");
                }
                if (updatedDashboard != null) {
                    MlModelRepository.this.addRelationship(updatedModel.getId(), updatedDashboard.getId(), "mlmodel", "dashboard", Relationship.USES);
                }
            }
        }
    }

    static class MlFeatureDescriptionTaskWorkflow
    extends EntityRepository.DescriptionTaskWorkflow {
        private final MlFeature mlFeature;

        MlFeatureDescriptionTaskWorkflow(FeedRepository.ThreadContext threadContext) {
            super(threadContext);
            MlModel mlModel = (MlModel)threadContext.getAboutEntity();
            this.mlFeature = MlModelRepository.findMlFeature(mlModel.getMlFeatures(), threadContext.getAbout().getArrayFieldName());
        }

        @Override
        public EntityInterface performTask(String user, ResolveTask resolveTask) {
            this.mlFeature.setDescription(resolveTask.getNewValue());
            return this.threadContext.getAboutEntity();
        }
    }

    static class MlFeatureTagTaskWorkflow
    extends EntityRepository.TagTaskWorkflow {
        private final MlFeature mlFeature;

        MlFeatureTagTaskWorkflow(FeedRepository.ThreadContext threadContext) {
            super(threadContext);
            MlModel mlModel = (MlModel)threadContext.getAboutEntity();
            this.mlFeature = MlModelRepository.findMlFeature(mlModel.getMlFeatures(), threadContext.getAbout().getArrayFieldName());
        }

        @Override
        public EntityInterface performTask(String user, ResolveTask resolveTask) {
            List<TagLabel> tags = JsonUtils.readObjects(resolveTask.getNewValue(), TagLabel.class);
            this.mlFeature.setTags(tags);
            return this.threadContext.getAboutEntity();
        }
    }
}

