/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.neo4j;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.IntStream;
import org.neo4j.cypherdsl.support.schema_name.SchemaNames;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Record;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.Value;
import org.neo4j.driver.Values;
import org.neo4j.driver.summary.ResultSummary;
import org.neo4j.driver.types.Node;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.neo4j.filter.Neo4jVectorFilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public class Neo4jVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(Neo4jVectorStore.class);
    public static final int DEFAULT_EMBEDDING_DIMENSION = 1536;
    public static final int DEFAULT_TRANSACTION_SIZE = 10000;
    public static final String DEFAULT_LABEL = "Document";
    public static final String DEFAULT_INDEX_NAME = "spring-ai-document-index";
    public static final String DEFAULT_EMBEDDING_PROPERTY = "embedding";
    public static final String DEFAULT_ID_PROPERTY = "id";
    public static final String DEFAULT_TEXT_PROPERTY = "text";
    public static final String DEFAULT_CONSTRAINT_NAME = "Document_unique_idx";
    private static final Map<Neo4jDistanceType, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(Neo4jDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, Neo4jDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN);
    private final Driver driver;
    private final SessionConfig sessionConfig;
    private final int embeddingDimension;
    private final Neo4jDistanceType distanceType;
    private final String embeddingProperty;
    private final String label;
    private final String indexName;
    private final String indexNameNotSanitized;
    private final String idProperty;
    private final String textProperty;
    private final String constraintName;
    private final Neo4jVectorFilterExpressionConverter filterExpressionConverter = new Neo4jVectorFilterExpressionConverter();
    private final boolean initializeSchema;

    protected Neo4jVectorStore(Builder builder) {
        super((AbstractVectorStoreBuilder)builder);
        Assert.notNull((Object)builder.driver, (String)"Neo4j driver must not be null");
        this.driver = builder.driver;
        this.sessionConfig = builder.sessionConfig;
        this.embeddingDimension = builder.embeddingDimension;
        this.distanceType = builder.distanceType;
        this.embeddingProperty = (String)SchemaNames.sanitize((String)builder.embeddingProperty).orElseThrow();
        this.label = (String)SchemaNames.sanitize((String)builder.label).orElseThrow();
        this.indexNameNotSanitized = builder.indexName;
        this.indexName = (String)SchemaNames.sanitize((String)builder.indexName, (boolean)true).orElseThrow();
        this.idProperty = (String)SchemaNames.sanitize((String)builder.idProperty).orElseThrow();
        this.textProperty = (String)SchemaNames.sanitize((String)builder.textProperty).orElseThrow();
        this.constraintName = (String)SchemaNames.sanitize((String)builder.constraintName).orElseThrow();
        this.initializeSchema = builder.initializeSchema;
    }

    public void doAdd(List<Document> documents) {
        List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        List<Map> rows = documents.stream().map(document -> this.documentToRecord((Document)document, (float[])embeddings.get(documents.indexOf(document)))).toList();
        try (Session session = this.driver.session();){
            String statement = "\tUNWIND $rows AS row\n\tMERGE (u:%s {%2$s: row.id})\n\t\tSET u += row.properties\n\tWITH row, u\n\tCALL db.create.setNodeVectorProperty(u, $embeddingProperty, row[$embeddingProperty])\n".formatted(this.label, this.idProperty);
            session.executeWrite(tx -> tx.run(statement, Map.of("rows", rows, "embeddingProperty", this.embeddingProperty)).consume());
        }
    }

    public void doDelete(List<String> idList) {
        try (Session session = this.driver.session(this.sessionConfig);){
            session.run("MATCH (n:%s) WHERE n.%s IN $ids\nCALL { WITH n DETACH DELETE n } IN TRANSACTIONS OF $transactionSize ROWS\n".formatted(this.label, this.idProperty), Map.of("ids", idList, "transactionSize", 10000)).consume();
        }
    }

    protected void doDelete(Filter.Expression filterExpression) {
        Assert.notNull((Object)filterExpression, (String)"Filter expression must not be null");
        try (Session session = this.driver.session(this.sessionConfig);){
            String whereClause = this.filterExpressionConverter.convertExpression(filterExpression);
            String cypher = "MATCH (node:%s) WHERE %s\nCALL { WITH node DETACH DELETE node } IN TRANSACTIONS OF $transactionSize ROWS\n".formatted(this.label, whereClause);
            ResultSummary summary = session.run(cypher, Map.of("transactionSize", 10000)).consume();
            logger.debug("Deleted {} nodes matching filter expression", (Object)summary.counters().nodesDeleted());
        }
        catch (Exception e) {
            logger.error("Failed to delete nodes by filter: {}", (Object)e.getMessage(), (Object)e);
            throw new IllegalStateException("Failed to delete nodes by filter", e);
        }
    }

    public List<Document> doSimilaritySearch(SearchRequest request) {
        Assert.isTrue((request.getTopK() > 0 ? 1 : 0) != 0, (String)"The number of documents to returned must be greater than zero");
        Assert.isTrue((request.getSimilarityThreshold() >= 0.0 && request.getSimilarityThreshold() <= 1.0 ? 1 : 0) != 0, (String)"The similarity score is bounded between 0 and 1; least to most similar respectively.");
        Value embedding = Values.value((float[])this.embeddingModel.embed(request.getQuery()));
        try (Session session = this.driver.session(this.sessionConfig);){
            StringBuilder condition = new StringBuilder("score >= $threshold");
            if (request.hasFilterExpression()) {
                condition.append(" AND ").append(this.filterExpressionConverter.convertExpression(request.getFilterExpression()));
            }
            String query = "CALL db.index.vector.queryNodes($indexName, $numberOfNearestNeighbours, $embeddingValue)\nYIELD node, score\nWHERE %s\nRETURN node, score".formatted(condition);
            List list = (List)session.executeRead(tx -> tx.run(query, Map.of("indexName", this.indexNameNotSanitized, "numberOfNearestNeighbours", request.getTopK(), "embeddingValue", embedding, "threshold", request.getSimilarityThreshold())).list(this::recordToDocument));
            return list;
        }
    }

    public void afterPropertiesSet() {
        if (!this.initializeSchema) {
            return;
        }
        try (Session session = this.driver.session(this.sessionConfig);){
            session.executeWriteWithoutResult(tx -> {
                tx.run("CREATE CONSTRAINT %s IF NOT EXISTS FOR (n:%s) REQUIRE n.%s IS UNIQUE".formatted(this.constraintName, this.label, this.idProperty)).consume();
                String statement = "CREATE VECTOR INDEX %s IF NOT EXISTS FOR (n:%s) ON (n.%s)\n\t\tOPTIONS {indexConfig: {\n\t\t`vector.dimensions`: %d,\n\t\t`vector.similarity_function`: '%s'\n\t\t}}\n".formatted(this.indexName, this.label, this.embeddingProperty, this.embeddingDimension, this.distanceType.name);
                tx.run(statement).consume();
            });
            session.run("CALL db.awaitIndexes()").consume();
        }
    }

    private Map<String, Object> documentToRecord(Document document, float[] embedding) {
        HashMap<String, Object> row = new HashMap<String, Object>();
        row.put(DEFAULT_ID_PROPERTY, document.getId());
        HashMap<String, String> properties = new HashMap<String, String>();
        properties.put(this.textProperty, document.getText());
        document.getMetadata().forEach((k, v) -> properties.put("metadata." + k, (String)Values.value((Object)v)));
        row.put("properties", properties);
        row.put(this.embeddingProperty, Values.value((float[])embedding));
        return row;
    }

    private Document recordToDocument(Record neoRecord) {
        Node node = neoRecord.get("node").asNode();
        float score = neoRecord.get("score").asFloat();
        HashMap<String, Float> metaData = new HashMap<String, Float>();
        metaData.put(DocumentMetadata.DISTANCE.value(), Float.valueOf(1.0f - score));
        node.keys().forEach(key -> {
            if (key.startsWith("metadata.")) {
                metaData.put(key.substring(key.indexOf(".") + 1), (Float)node.get(key).asObject());
            }
        });
        return Document.builder().id(node.get(this.idProperty).asString()).text(node.get(this.textProperty).asString()).metadata(Map.copyOf(metaData)).score(Double.valueOf(score)).build();
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder((String)VectorStoreProvider.NEO4J.value(), (String)operationName).collectionName(this.indexName).dimensions(Integer.valueOf(this.embeddingModel.dimensions())).similarityMetric(this.getSimilarityMetric());
    }

    private String getSimilarityMetric() {
        if (!SIMILARITY_TYPE_MAPPING.containsKey((Object)this.distanceType)) {
            return this.distanceType.name();
        }
        return SIMILARITY_TYPE_MAPPING.get((Object)this.distanceType).value();
    }

    public <T> Optional<T> getNativeClient() {
        Driver client = this.driver;
        return Optional.of(client);
    }

    public static Builder builder(Driver driver, EmbeddingModel embeddingModel) {
        return new Builder(driver, embeddingModel);
    }

    public static class Builder
    extends AbstractVectorStoreBuilder<Builder> {
        private final Driver driver;
        private SessionConfig sessionConfig = SessionConfig.defaultConfig();
        private int embeddingDimension = 1536;
        private Neo4jDistanceType distanceType = Neo4jDistanceType.COSINE;
        private String label = "Document";
        private String embeddingProperty = "embedding";
        private String indexName = "spring-ai-document-index";
        private String idProperty = "id";
        private String textProperty = "text";
        private String constraintName = "Document_unique_idx";
        private boolean initializeSchema = false;

        private Builder(Driver driver, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull((Object)driver, (String)"Neo4j driver must not be null");
            this.driver = driver;
        }

        public Builder databaseName(String databaseName) {
            if (StringUtils.hasText((String)databaseName)) {
                this.sessionConfig = SessionConfig.forDatabase((String)databaseName);
            }
            return this;
        }

        public Builder sessionConfig(SessionConfig sessionConfig) {
            this.sessionConfig = sessionConfig;
            return this;
        }

        public Builder embeddingDimension(int dimension) {
            Assert.isTrue((dimension >= 1 ? 1 : 0) != 0, (String)"Dimension has to be positive");
            this.embeddingDimension = dimension;
            return this;
        }

        public Builder distanceType(Neo4jDistanceType distanceType) {
            Assert.notNull((Object)((Object)distanceType), (String)"Distance type may not be null");
            this.distanceType = distanceType;
            return this;
        }

        public Builder label(String label) {
            if (StringUtils.hasText((String)label)) {
                this.label = label;
            }
            return this;
        }

        public Builder embeddingProperty(String embeddingProperty) {
            if (StringUtils.hasText((String)embeddingProperty)) {
                this.embeddingProperty = embeddingProperty;
            }
            return this;
        }

        public Builder indexName(String indexName) {
            if (StringUtils.hasText((String)indexName)) {
                this.indexName = indexName;
            }
            return this;
        }

        public Builder idProperty(String idProperty) {
            if (StringUtils.hasText((String)idProperty)) {
                this.idProperty = idProperty;
            }
            return this;
        }

        public Builder textProperty(String textProperty) {
            if (StringUtils.hasText((String)textProperty)) {
                this.textProperty = textProperty;
            }
            return this;
        }

        public Builder constraintName(String constraintName) {
            if (StringUtils.hasText((String)constraintName)) {
                this.constraintName = constraintName;
            }
            return this;
        }

        public Builder initializeSchema(boolean initializeSchema) {
            this.initializeSchema = initializeSchema;
            return this;
        }

        public Neo4jVectorStore build() {
            return new Neo4jVectorStore(this);
        }
    }

    public static enum Neo4jDistanceType {
        COSINE("cosine"),
        EUCLIDEAN("euclidean");

        public final String name;

        private Neo4jDistanceType(String name) {
            this.name = name;
        }
    }
}

