From e6421012965b68f04534c45350a818ef68cc20e9 Mon Sep 17 00:00:00 2001 From: tourtourigny Date: Wed, 20 Aug 2025 13:55:09 -0400 Subject: [PATCH 1/3] feat natural language query by REST API --- Cargo.lock | 185 +++-- Cargo.toml | 2 +- ai-search/Cargo.toml | 23 + ai-search/README.md | 12 + ai-search/ressources/traversal_prompt.md | 11 + ai-search/src/automatic_search.rs | 507 +++++++++++++ ai-search/src/full_ai_search.rs | 665 ++++++++++++++++++ ai-search/src/main.rs | 97 +++ grc20-core/src/mapping/entity/models.rs | 6 +- .../mapping/entity/search_with_traversals.rs | 2 +- .../src/mapping/entity/semantic_search.rs | 2 +- grc20-core/src/mapping/mod.rs | 2 +- grc20-core/src/mapping/relation/find_many.rs | 2 +- grc20-core/src/mapping/triple.rs | 10 +- .../resources/get_entity_info_description.md | 78 -- mcp-server/resources/instructions.md | 169 +---- .../name_search_entity_description.md | 39 - .../resources/search_entity_description.md | 69 +- .../search_relation_type_description.md | 22 - .../resources/search_space_description.md | 14 - .../resources/search_type_description.md | 12 - mcp-server/src/main.rs | 240 +++---- 22 files changed, 1584 insertions(+), 585 deletions(-) create mode 100644 ai-search/Cargo.toml create mode 100644 ai-search/README.md create mode 100644 ai-search/ressources/traversal_prompt.md create mode 100644 ai-search/src/automatic_search.rs create mode 100644 ai-search/src/full_ai_search.rs create mode 100644 ai-search/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 1bafabf..81f2b70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,27 @@ dependencies = [ "memchr", ] +[[package]] +name = "ai-search" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum 0.8.4", + "clap", + "fastembed", + "futures", + "grc20-core", + "grc20-sdk", + "rand 0.8.5", + "regex", + "rig-core", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "aligned-vec" version = "0.5.0" @@ -151,6 +172,12 @@ version = 
"0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "as-any" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0f477b951e452a0b6b4a10b53ccd569042d1d01729b519e02074a9c0958a063" + [[package]] name = "ascii-canvas" version = "3.0.0" @@ -669,7 +696,7 @@ dependencies = [ "serde_json", "serde_repr", "serde_urlencoded", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tokio-util", "tower-service", @@ -720,9 +747,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" dependencies = [ "serde", ] @@ -1632,6 +1659,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "gloo-timers" version = "0.3.0" @@ -1662,7 +1695,7 @@ dependencies = [ "serde_json", "serde_with", "testcontainers", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tracing", "uuid", @@ -1798,10 +1831,10 @@ dependencies = [ "log", "native-tls", "rand 0.8.5", - "reqwest 0.12.12", + "reqwest 0.12.22", "serde", "serde_json", - "thiserror 2.0.11", + "thiserror 2.0.12", "ureq", "windows-sys 0.59.0", ] @@ -2032,21 +2065,28 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" dependencies = [ + "base64 0.22.1", "bytes", "futures-channel", + "futures-core", "futures-util", "http 1.2.0", "http-body 1.0.1", "hyper 1.6.0", + "ipnet", + "libc", + "percent-encoding", "pin-project-lite", "socket2", + "system-configuration 0.6.1", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] @@ -2331,8 +2371,8 @@ name = "ipfs" version = "0.1.0" dependencies = [ "prost", - "reqwest 0.12.12", - "thiserror 2.0.11", + "reqwest 0.12.22", + "thiserror 2.0.12", "tokio", "tracing", ] @@ -2343,6 +2383,16 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -2568,9 +2618,9 @@ checksum = "db13adb97ab515a3691f56e4dbab09283d0b86cb45abd991d8634a9d6f501760" [[package]] name = "libc" -version = "0.2.169" +version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] name = "libfuzzer-sys" @@ -2769,6 +2819,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", 
+] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -3087,6 +3147,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01" +dependencies = [ + "num-traits", +] + [[package]] name = "ort" version = "2.0.0-rc.9" @@ -3762,9 +3831,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.12" +version = "0.12.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" +checksum = "cbc931937e6ca3a06e3b6c0aa7841849b160a90351d6ab467a8b9b9959767531" dependencies = [ "base64 0.22.1", "bytes", @@ -3779,31 +3848,29 @@ dependencies = [ "hyper-rustls", "hyper-tls 0.6.0", "hyper-util", - "ipnet", "js-sys", "log", "mime", + "mime_guess", "native-tls", - "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper 1.0.2", - "system-configuration 0.6.1", "tokio", "tokio-native-tls", "tokio-util", "tower 0.5.2", + "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", "web-sys", - "windows-registry", ] [[package]] @@ -3818,6 +3885,29 @@ version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" +[[package]] +name = "rig-core" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e57d16b6bc8ed1a4d7cea51782f1c6320e0a98b01a8de719540251a81b027ffb" +dependencies = [ + "as-any", + "async-stream", + "base64 0.22.1", + "bytes", + "futures", + "glob", + "mime_guess", + "ordered-float", + "reqwest 
0.12.22", + "schemars", + "serde", + "serde_json", + "thiserror 2.0.12", + "tracing", + "url", +] + [[package]] name = "ring" version = "0.17.8" @@ -3849,7 +3939,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tokio-stream", "tokio-util", @@ -4294,7 +4384,7 @@ dependencies = [ "prometheus", "prost", "prost-types", - "reqwest 0.12.12", + "reqwest 0.12.22", "serde", "serde_json", "serde_path_to_error", @@ -4302,7 +4392,7 @@ dependencies = [ "substreams-utils", "test-log", "testcontainers", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tracing", "tracing-appender", @@ -4355,9 +4445,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", "windows-sys 0.52.0", @@ -4663,7 +4753,7 @@ dependencies = [ "serde", "serde_json", "serde_with", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tokio-stream", "tokio-tar", @@ -4682,11 +4772,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -4702,9 +4792,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", @@ -4807,7 +4897,7 @@ 
dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror 2.0.11", + "thiserror 2.0.12", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", @@ -5019,14 +5109,18 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.2" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" dependencies = [ "bitflags 2.8.0", "bytes", + "futures-util", "http 1.2.0", + "http-body 1.0.1", + "iri-string", "pin-project-lite", + "tower 0.5.2", "tower-layer", "tower-service", ] @@ -5160,6 +5254,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.16" @@ -5525,32 +5625,31 @@ checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" [[package]] name = "windows-registry" -version = "0.2.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = "b3bab093bdd303a1240bb99b8aba8ea8a69ee19d34c9e2ef9594e708a4878820" dependencies = [ + "windows-link", "windows-result", "windows-strings", - "windows-targets 0.52.6", ] [[package]] name = "windows-result" -version = "0.2.0" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-targets 0.52.6", + 
"windows-link", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-result", - "windows-targets 0.52.6", + "windows-link", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 29e4a9a..277fcaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,5 +8,5 @@ members = [ "web3-utils", "grc20-core", "grc20-macros", - "grc20-sdk", "mcp-server", + "grc20-sdk", "mcp-server", "ai-search", ] diff --git a/ai-search/Cargo.toml b/ai-search/Cargo.toml new file mode 100644 index 0000000..76f7ddf --- /dev/null +++ b/ai-search/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "ai-search" +version = "0.1.0" +edition = "2024" + +[dependencies] +axum= "0.8.4" +anyhow = "1.0.89" +clap = { version = "4.5.39", features = ["derive", "env"] } +fastembed = "4.8.0" +futures = "0.3.31" +tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } + +grc20-core = { version = "0.1.0", path = "../grc20-core" } +grc20-sdk = { version = "0.1.0", path = "../grc20-sdk" } + +rand = "0.8" +regex = "1.10" +rig-core = "0.17.1" +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" +tracing = "0.1.41" +tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } diff --git a/ai-search/README.md b/ai-search/README.md new file mode 100644 index 0000000..270f192 --- /dev/null +++ b/ai-search/README.md @@ -0,0 +1,12 @@ +# README ai-search + +## API +ai-search is a REST API that can be queried at the adress 127.0.0.0:3000. +The only available route is /question that takes a JSON containing the question for the knowledge graph. 
+ + +## Start command +``` +cargo run --bin ai-search -- --neo4j-uri neo4j://localhost:7687 --neo4j-user neo4j --neo4j-pass neo4j --gemini-api-key + +``` \ No newline at end of file diff --git a/ai-search/ressources/traversal_prompt.md b/ai-search/ressources/traversal_prompt.md new file mode 100644 index 0000000..8d372eb --- /dev/null +++ b/ai-search/ressources/traversal_prompt.md @@ -0,0 +1,11 @@ +You are a Knowledge Graph traversal agent. Your goal is to answer the user’s question using the fewest steps possible. +You may: + +ID to discover an unexplored neighbor’s details. + + to give a partial answer if you have some important information. + + if no path can yield the answer. + +Never re-explore a node. Always choose the neighbor(s) most likely to contain the answer or get some more relevant information. Try to expand until you have an answer. You can explore multiple nodes and expansions at the same time with multiple explore. + diff --git a/ai-search/src/automatic_search.rs b/ai-search/src/automatic_search.rs new file mode 100644 index 0000000..e9a47ed --- /dev/null +++ b/ai-search/src/automatic_search.rs @@ -0,0 +1,507 @@ +use axum::{extract::Json, http::StatusCode}; +use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; +use futures::{TryStreamExt, future::join_all, pin_mut}; +use grc20_core::{ + entity::{self, Entity, EntityFilter, EntityNode}, + mapping::{ + Query, QueryStream, RelationEdge, prop_filter, + triple::{self, SemanticSearchResult}, + }, + neo4rs, + relation::{self, RelationFilter}, + system_ids, +}; +use grc20_sdk::models::BaseEntity; +use regex::Regex; +use rig::{agent::Agent, completion::Prompt, providers::gemini::completion::CompletionModel}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +const EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::AllMiniLML6V2; + +pub struct AutomaticSearchAgent { + neo4j: neo4rs::Graph, + pub embedding_model: Arc, + pub fast_agent: Agent, + pub thinking_agent: Agent, +} + +impl 
AutomaticSearchAgent { + pub fn new(neo4j: neo4rs::Graph, gemini_api_key: &str) -> Self { + let preamble = "You are a knowledge graph helper to answer natural language questions."; + let gemini_client = rig::providers::gemini::Client::new(gemini_api_key); + + Self { + neo4j, + embedding_model: Arc::new( + TextEmbedding::try_new( + InitOptions::new(EMBEDDING_MODEL).with_show_download_progress(true), + ) + .expect("Failed to initialize embedding model"), + ), + fast_agent: gemini_client + .agent("gemini-2.5-flash-lite") + .temperature(0.4) + .preamble(preamble) + .build(), + thinking_agent: gemini_client + .agent("gemini-2.5-flash") + .temperature(0.4) + .preamble(preamble) + .build(), + } + } + + pub async fn natural_language_question( + &self, + question: String, + ) -> Result, StatusCode> { + let string_extraction_regex = Regex::new(r#""([^"]+)""#).unwrap(); + + let create_relation_filter = |search_result: SemanticSearchResult| { + RelationFilter::default() + .from_(EntityFilter::default().id(prop_filter::value(search_result.triple.entity))) + }; + + let main_entity = self.fast_agent.prompt(format!("Can you extract the main 1-3 Entities from which you can do a search from the original question. THE ANSWER SHOULD BE A SINGLE WORD OR A SINGLE CONCEPT IN QUOTATION MARKS. IF THERE IS MORE THAN ONE IMPORTANT CONCEPT YOU CAN EXTRACT THEM EACH IN THEIR OWN QUOTATION MARKS. Here's the question: {question}")).await.unwrap_or("".to_string()); + + let entities: Vec = string_extraction_regex + .captures_iter(&main_entity) + .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string())) + .collect(); + + let number_answers = entities.len(); + + tracing::info!("important word found: {entities:?}"); + + let important_concepts = self.fast_agent.prompt(format!("Can you extract the all keywords and important concepts from the original question. You can also add keywords related that aren't directly in the question. 
THE ANSWER SHOULD BE A SERIE OF SINGLE WORD OR A SINGLE CONCEPT IN QUOTATION MARKS. OUTPUT THEM EACH IN THEIR OWN QUOTATION MARKS. Here's the question: {question}")).await.unwrap_or("".to_string()); + + let concepts: Vec = string_extraction_regex + .captures_iter(&important_concepts) + .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string())) + .collect(); + + tracing::info!("important concepts are: {concepts:?}"); + + let concepts_embeddings: Vec> = concepts + .into_iter() + .map(|concept| { + self.embedding_model + .embed(vec![concept], None) + .expect("Failed to get embedding") + .pop() + .expect("Embedding is empty") + .into_iter() + .map(|v| v as f64) + .collect::>() + }) + .collect(); + + let expanded_paths = join_all( + join_all(entities.into_iter().map(|entity| async { + let semantic_search = self.search(entity, Some(1)).await.unwrap_or_default(); + self.get_ids_from_search(semantic_search, create_relation_filter) + .await + .unwrap_or_default() + })) + .await + .into_iter() + .flatten() + .map(|sart_node_id| async { + self.automatic_explore_node(sart_node_id, concepts_embeddings.clone()) + .await + }), + ) + .await + .to_vec(); + + let answers_combined = join_all(expanded_paths.iter().map(|path| async { + let agent_prompt = format!( + "From the search that was done can you provide a final answer? Here's the original question: {question}\nThe full search of nodes:{}", path.clone() + ); + + self.thinking_agent + .prompt(agent_prompt) + .await.unwrap_or("Agent error".to_string()) + })).await.join("]\n["); + + let final_answer = if number_answers == 1 { + answers_combined + } else { + let final_answer_prompt = format!( + "BASE YOUR ANSWER ONLY ON THE PARTIAL ANSWERS! YOU ANSWER ONLY THE ORIGINAL QUESTION SINCE THE MAIN USER DOESN'T CARE ABOUT PARTIAL ANSWERS. GIVE A FULL COMPLETE AND DETAILLED ANSWER. 
From the given partial answer from different starting points, can you provide a final single answer to the original question: {question}\n Here are the different partial answers:[{answers_combined}]" + ); + self.fast_agent + .prompt(final_answer_prompt) + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + }; + + Ok(Json(final_answer)) + } + + async fn automatic_explore_node( + &self, + start_id: String, + concepts_embeddings: Vec>, + ) -> String { + let mut seen_ids = HashSet::new(); + + let start_name = self + .get_name_of_id(start_id.clone()) + .await + .unwrap_or("No name".to_string()); + + let mut path_root = PathNode { + id: start_id.clone(), + relation_name: "Start".to_string(), + relation_name_embedding: Vec::new(), + entity_name: start_name, + entity_name_embedding: Vec::new(), + nodes: Vec::new(), + is_inbound: false, + page_rank: 0.0, + depth: 0, + }; + + let mut stack: Vec<&mut PathNode> = vec![&mut path_root]; + + while let Some(current) = stack.pop() { + if !seen_ids.insert(current.id.clone()) { + continue; + } + + let neighbor_nodes = self + .automatic_expand_entity( + current.id.clone(), + current.depth, + concepts_embeddings.clone(), + ) + .await + .unwrap_or_default(); + + for neighbor in neighbor_nodes { + current.nodes.push(PathNode { + nodes: Vec::new(), + id: neighbor.id.clone(), + depth: neighbor.depth + 1, + ..neighbor + }); + } + + current + .nodes + .iter_mut() + .for_each(|neighbor| stack.insert(0, neighbor)); + } + + self.create_path(&path_root, 0).await + } + + async fn automatic_expand_entity( + &self, + base_id: String, + depth: usize, + concepts_embeddings: Vec>, + ) -> Result, StatusCode> { + const MAX_SAME_RELATION: usize = 10; + const NAME_IMPORTANCE_FACTOR: f64 = 0.9; + const MAX_DISPLAYED_NEIGHBORS: usize = 35; + const EMBEDDING_DISTANCE_THRESHOLD: f64 = 0.65; + + let mut relation_count = HashMap::new(); + + let mut relations = self + 
.extract_path_nodes_to_neighbors(base_id.clone(), depth, false) + .await + .unwrap_or_default(); + relations.extend( + self.extract_path_nodes_to_neighbors(base_id.clone(), depth, true) + .await + .unwrap_or_default(), + ); + + let mut relations_with_scores: Vec<(f64, PathNode)> = relations + .into_iter() + .map(|path_node| { + // Compute all cosine distances to concept embeddings and take the best (lowest) + let distance_score = concepts_embeddings + .iter() + .map(|concept| { + let name_dist = + self.cosine_distance_unit(concept, &path_node.entity_name_embedding); + let rel_dist = + self.cosine_distance_unit(concept, &path_node.relation_name_embedding); + name_dist.min(rel_dist) * NAME_IMPORTANCE_FACTOR + + name_dist.max(rel_dist) * (1.0 - NAME_IMPORTANCE_FACTOR) + }) + .fold(f64::INFINITY, |a, b| a.min(b)); + + let depth_penalty = 0.02 * depth as f64; + let depth_centered = 3.5 - depth as f64; + let depth_factor = depth_centered.powf(1.0 / 5.0) * 20.0; + let page_rank_clipped = path_node.page_rank.clamp(-0.5, 0.8); + let final_score = + depth_penalty + distance_score * (0.9 - page_rank_clipped / depth_factor); // /depth_factor + + (final_score, path_node) + }) + .collect(); + + relations_with_scores + .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Less)); + + let relations: Vec = relations_with_scores + .into_iter() + .filter(|(score, _)| score < &EMBEDDING_DISTANCE_THRESHOLD) + .filter_map(|(_, path_node)| { + let count = relation_count + .entry(path_node.relation_name.clone()) + .and_modify(|c| *c += 1) + .or_insert(1); + if *count <= MAX_SAME_RELATION { + Some(path_node) + } else { + None + } + }) + .take(MAX_DISPLAYED_NEIGHBORS) + .collect(); + + Ok(relations) + } + + async fn search( + &self, + query: String, + limit: Option, + ) -> Result, StatusCode> { + let embedding = self + .embedding_model + .embed(vec![&query], None) + .expect("Failed to get embedding") + .pop() + .expect("Embedding is empty") + .into_iter() + .map(|v| v as f64) 
+ .collect::>(); + + let limit = limit.unwrap_or(10); + + let semantic_search_triples = triple::search(&self.neo4j, embedding) + .limit(limit) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + Ok(semantic_search_triples) + } + + async fn get_ids_from_search( + &self, + search_triples: Vec, + create_relation_filter: impl Fn(SemanticSearchResult) -> RelationFilter, + ) -> Result, StatusCode> { + let mut seen_ids: HashSet = HashSet::new(); + let mut result_ids: Vec = Vec::new(); + + for semantic_search_triple in search_triples { + let filtered_for_types = relation::find_many::>(&self.neo4j) + .filter(create_relation_filter(semantic_search_triple)) + .send() + .await; + + //We only need to get the first relation since they would share the same entity id + if let Ok(stream) = filtered_for_types { + pin_mut!(stream); + if let Some(edge) = stream.try_next().await.ok().flatten() { + let id = edge.from.id; + if seen_ids.insert(id.clone()) { + result_ids.push(id); + } + } + } + } + Ok(result_ids) + } + + async fn extract_path_nodes_to_neighbors( + &self, + id: String, + depth: usize, + is_inbound: bool, + ) -> Result, StatusCode> { + const MAX_RELATIONS_CONSIDERED: usize = 100; + + let in_filter = relation::RelationFilter::default() + .to_(EntityFilter::default().id(prop_filter::value(id.clone()))); + let out_filter = relation::RelationFilter::default() + .from_(EntityFilter::default().id(prop_filter::value(id.clone()))); + + let relations = relation::find_many::>(&self.neo4j) + .filter(if is_inbound { in_filter } else { out_filter }) + .limit(MAX_RELATIONS_CONSIDERED) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let related_nodes = join_all(relations.into_iter().map(|result| async move { + let page_rank: f64; + let id: &str; + + if is_inbound { + id = &result.from.id; + page_rank = result.from.system_properties.page_rank.unwrap_or(0.0); + } else { + id = &result.to.id; + page_rank = result.to.system_properties.page_rank.unwrap_or(0.0); + }; + let (entity_name_embedding, entity_name) = self + .get_name_and_embedding_of_id(id.to_string()) + .await + .unwrap_or((Vec::new(), "No name".to_string())); + let (relation_name_embedding, relation_name) = self + .get_name_and_embedding_of_id(result.relation_type.clone()) + .await + .unwrap_or((Vec::new(), "No name".to_string())); + + PathNode { + id: id.to_string(), + relation_name, + relation_name_embedding, + entity_name, + entity_name_embedding, + nodes: Vec::new(), + is_inbound, + page_rank, + depth: depth + 1, + } + })) + .await + .to_vec(); + + Ok(related_nodes) + } + + async fn create_path(&self, path_node: &PathNode, depth: usize) -> String { + const MAX_DISPLAY_LENGTH: usize = 8; + + let shorten_string = |word: String, max_length: usize| { + let splitted: Vec<&str> = word.split_whitespace().collect(); + if splitted.len() > max_length { + splitted[..max_length].join(" ") + "..." 
+ } else { + splitted.join(" ") + } + }; + + let mut base = format!( + "{}{}-[{}]-{}{}\n", + " ".repeat(depth), + if path_node.is_inbound { "<" } else { "" }, + shorten_string(path_node.relation_name.clone(), MAX_DISPLAY_LENGTH), + if path_node.is_inbound { "" } else { ">" }, + shorten_string(path_node.entity_name.clone(), MAX_DISPLAY_LENGTH), + ); + + let rest = join_all( + path_node + .nodes + .iter() + .map(|node| async { self.create_path(node, depth + 1).await }), + ) + .await + .join(""); + + base.push_str(&rest); + base + } + + async fn get_name_of_id(&self, id: String) -> Result { + let entity = entity::find_one::>(&self.neo4j, &id) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + entity + .attributes + .name + .ok_or(StatusCode::INTERNAL_SERVER_ERROR) + } + + async fn get_name_and_embedding_of_id( + &self, + id: String, + ) -> Result<(Vec, String), StatusCode> { + let triple_result = triple::find_many(&self.neo4j) + .entity_id(prop_filter::value(&id)) + .attribute_id(prop_filter::value(system_ids::NAME_ATTRIBUTE)) + .limit(1) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + if let Some(name_entity) = triple_result.first() { + if let Some(embedding) = &name_entity.embedding { + return Ok((embedding.clone(), name_entity.value.value.clone())); + } + } + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + + #[inline(always)] + fn cosine_distance_unit(&self, a: &[f64], b: &[f64]) -> f64 { + let mut dot = 0.0; + for (&x, &y) in a.iter().zip(b.iter()) { + dot += x * y; + } + 1.0 - dot + } +} + +#[derive(Clone, Debug)] +struct PathNode { + id: String, + relation_name: String, + relation_name_embedding: Vec, + entity_name: String, + entity_name_embedding: Vec, + nodes: Vec, + is_inbound: bool, + page_rank: f64, + depth: usize, +} diff --git a/ai-search/src/full_ai_search.rs b/ai-search/src/full_ai_search.rs new file mode 100644 index 0000000..2d96228 --- /dev/null +++ b/ai-search/src/full_ai_search.rs @@ -0,0 +1,665 @@ +use anyhow::Error; +use axum::{extract::Json, http::StatusCode}; +use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; +use futures::{TryStreamExt, future::join_all, pin_mut}; +use grc20_core::{ + entity::{self, Entity, EntityFilter, EntityNode}, + mapping::{ + Query, QueryStream, RelationEdge, prop_filter, + triple::{self, SemanticSearchResult}, + }, + neo4rs, + relation::{self, RelationFilter}, +}; +use grc20_sdk::models::BaseEntity; +use rand::{Rng, distributions::Alphanumeric}; +use regex::Regex; +use rig::{agent::Agent, completion::Prompt, providers::gemini::completion::CompletionModel}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; +use tokio::sync::Mutex; + +const EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::AllMiniLML6V2; + +pub struct FullAISearchAgent { + neo4j: neo4rs::Graph, + pub embedding_model: Arc, + pub fast_agent: Agent, + pub thinking_agent: Agent, +} + +#[derive(Clone, Debug)] +struct PathNode { + id: String, + relation_name: String, + entity_name: String, + 
nodes: Vec, + is_explored: bool, + is_inbound: bool, + is_hidden: bool, + depth: usize, +} + +type PathRef = Arc>; + +impl FullAISearchAgent { + pub fn new(neo4j: neo4rs::Graph, gemini_api_key: &str) -> Self { + let traversal_system_prompt = include_str!("../ressources/traversal_prompt.md"); + let preamble = "You are a knowledge graph helper to answer natural language questions."; + let gemini_client = rig::providers::gemini::Client::new(gemini_api_key); + + Self { + neo4j, + embedding_model: Arc::new( + TextEmbedding::try_new( + InitOptions::new(EMBEDDING_MODEL).with_show_download_progress(true), + ) + .expect("Failed to initialize embedding model"), + ), + fast_agent: gemini_client + .agent("gemini-2.5-flash-lite") + .temperature(0.4) + .preamble(traversal_system_prompt) + .build(), + thinking_agent: gemini_client + .agent("gemini-2.5-flash") + .temperature(0.4) + .preamble(preamble) + .build(), + } + } + + pub async fn natural_language_question( + &self, + question: String, + ) -> Result, StatusCode> { + let string_extraction_regex = Regex::new(r#""([^"]+)""#).unwrap(); + + let create_relation_filter = |search_result: SemanticSearchResult| { + RelationFilter::default() + .from_(EntityFilter::default().id(prop_filter::value(search_result.triple.entity))) + }; + + let main_entity = self.thinking_agent.prompt(format!("Can you extract the main 1-3 Entities from which you can do a search from the original question. THE ANSWER SHOULD BE A SINGLE WORD OR A SINGLE CONCEPT IN QUOTATION MARKS. IF THERE IS MORE THAN ONE IMPORTANT CONCEPT YOU CAN EXTRACT THEM EACH IN THEIR OWN QUOTATION MARKS. 
Here's the question: {question}")).await.unwrap_or("".to_string()); + + let entities: Vec = string_extraction_regex + .captures_iter(&main_entity) + .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string())) + .collect(); + + let number_answers = entities.len(); + + tracing::info!("important word found: {entities:?}"); + + let important_concepts = self.thinking_agent.prompt(format!("Can you extract the all keywords and important concepts from the original question. You can also add keywords related that aren't directly in the question. THE ANSWER SHOULD BE A SERIE OF SINGLE WORD OR A SINGLE CONCEPT IN QUOTATION MARKS. OUTPUT THEM EACH IN THEIR OWN QUOTATION MARKS. Here's the question: {question}")).await.unwrap_or("".to_string()); + + let concepts: Vec = string_extraction_regex + .captures_iter(&important_concepts) + .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string())) + .collect(); + + tracing::info!("important concepts are: {concepts:?}"); + + let concepts_embeddings: Vec> = self + .embedding_model + .embed(concepts, None) + .expect("Failed to get embedding") + .into_iter() + .map(|vec| vec.into_iter().map(|v| v as f64).collect()) + .collect::>(); + + let expanded_paths = join_all( + join_all(entities.into_iter().map(|entity| async { + let semantic_search = self.search(entity, Some(1)).await.unwrap_or_default(); + self.get_ids_from_search(semantic_search, create_relation_filter) + .await + .unwrap_or_default() + })) + .await + .into_iter() + .flatten() + .map(|start_node_id| async { + self.explore_node(question.clone(), concepts_embeddings.clone(), start_node_id) + .await + }), + ) + .await + .to_vec(); + + let answers_combined = join_all(expanded_paths.iter().map(|path| async { + let agent_prompt = format!( + "From the search that was done can you provide a final answer? 
Here's the original question: {question}\nThe full search of nodes:{}", path.clone() + ); + + self.thinking_agent + .prompt(agent_prompt) + .await + .unwrap_or("Agent error".to_string()) + })).await + .join("]\n["); + + let final_answer = if number_answers == 1 { + Ok(answers_combined) + } else { + let final_answer_prompt = format!( + "BASE YOUR ANSWER ONLY ON THE PARTIAL ANSWERS! YOU ANSWER ONLY THE ORIGINAL QUESTION SINCE THE MAIN USER DOESN'T CARE ABOUT PARTIAL ANSWERS. GIVE A FULL COMPLETE AND DETAILLED ANSWER. From the given partial answer from different starting points, can you provide a final single answer to the original question: {question}\n Here are the different partial answers:[{answers_combined}]" + ); + self.thinking_agent.prompt(final_answer_prompt).await + }; + + match final_answer { + Ok(answer) => Ok(Json(answer)), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + } + } + + async fn create_path(&self, path_node: &PathNode, depth: usize) -> String { + if path_node.is_hidden || depth > 10 { + return "".to_string(); + } + const MAX_DISPLAY_LENGTH: usize = 8; + + let shorten_string = |word: String, max_length: usize| { + let splitted: Vec<&str> = word.split_whitespace().collect(); + if splitted.len() > max_length && !path_node.is_explored { + splitted[..max_length].join(" ") + "..." + } else { + splitted.join(" ") + } + }; + + let mut base = format!( + "{}{}-[{}]-{}{}{}({})\n", + " ".repeat(depth), + if path_node.is_inbound { "<" } else { "" }, + shorten_string(path_node.relation_name.clone(), MAX_DISPLAY_LENGTH), + if path_node.is_inbound { "" } else { ">" }, + if path_node.is_explored { "+" } else { "?" 
}, + shorten_string(path_node.entity_name.clone(), MAX_DISPLAY_LENGTH), + path_node.id.chars().take(6).collect::(), + ); + + let rest = join_all(path_node.nodes.iter().map(|node| async { + self.create_path(&node.lock().await.clone(), depth + 1) + .await + })) + .await + .join(""); + + base.push_str(&rest); + base + } + + async fn explore_node( + &self, + question: String, + concepts_embeddings: Vec>, + start_id: String, + ) -> String { + let explore_regex = Regex::new(r"(.*?)").unwrap(); + let answer_regex = Regex::new(r"(.*?)").unwrap(); + let end_regex = Regex::new(r"").unwrap(); + + let mut full_id_resolver: HashMap = HashMap::new(); + let mut expansions: HashMap> = HashMap::new(); + let mut seen_ids = HashSet::new(); + let mut seen_twice = HashSet::new(); + let mut explores = Vec::new(); + let mut partial_answers = Vec::new(); + let mut potential_explore = Vec::new(); + + let start_name = self + .get_name_of_id(start_id.clone()) + .await + .unwrap_or("No name".to_string()); + + let path_root: PathRef = Arc::new(Mutex::new(PathNode { + id: start_id.clone(), + relation_name: "Start".to_string(), + entity_name: start_name, + nodes: Vec::new(), + is_explored: false, + is_inbound: false, + is_hidden: false, + depth: 0, + })); + + potential_explore.push(path_root.clone()); + + let mut stack: Vec = vec![path_root.clone()]; + full_id_resolver.insert(start_id.chars().take(6).collect(), start_id); + + while let Some(current) = stack.pop() { + let current_node = current.lock().await; + let mut warning = false; + + if !seen_ids.insert(current_node.id.clone()) { + warning = true; + if !seen_twice.insert(current_node.id.clone()) { + continue; + } + } + + if partial_answers.len() >= 10 { + break; + } + + let data = format!( + "name: {}, description:{}", + current_node.entity_name, + self.get_description_of_id(current_node.id.clone()) + .await + .unwrap_or("No description".to_string()) + ); + + let (neighbor_nodes, hidden_options) = self + 
.expand_entity(current_node.id.clone(), concepts_embeddings.clone()) + .await + .unwrap_or_default(); + + drop(current_node); + let mut current_node_mut = current.lock().await; + current_node_mut.is_explored = true; + + for neighbor in neighbor_nodes.clone() { + let new_node = Arc::new(Mutex::new(PathNode { + nodes: Vec::new(), + is_explored: seen_ids.contains(&neighbor.id.clone()), + id: neighbor.id.clone(), + depth: neighbor.depth + 1, + ..neighbor + })); + full_id_resolver.insert(neighbor.id.chars().take(6).collect(), neighbor.id); + + if !warning && !neighbor.is_explored { + current_node_mut.nodes.push(new_node.clone()); + potential_explore.push(new_node); + } + } + + for (relation_name, hidden_path_nodes) in hidden_options { + let expand_id: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(6) + .map(char::from) + .collect(); + + let expansion_node = Arc::new(Mutex::new(PathNode { + id: expand_id.clone(), + relation_name: relation_name.clone(), + entity_name: format!( + "Expand {} more {relation_name}...", + hidden_path_nodes.len() + ), + nodes: Vec::new(), + is_explored: false, + is_inbound: false, + is_hidden: false, + depth: 0, + })); + if !warning { + current_node_mut.nodes.push(expansion_node.clone()); + } + + for neighbor in hidden_path_nodes { + let new_node = Arc::new(Mutex::new(PathNode { + nodes: Vec::new(), + is_explored: seen_ids.contains(&neighbor.id.clone()), + id: neighbor.id.clone(), + is_hidden: true, + ..neighbor + })); + full_id_resolver.insert(neighbor.id.chars().take(6).collect(), neighbor.id); + expansions + .entry(expand_id.clone()) + .or_default() + .push(new_node.clone()); + + if !warning && !neighbor.is_explored { + current_node_mut.nodes.push(new_node.clone()); + potential_explore.push(new_node); + } + } + } + + drop(current_node_mut); + + let named_path: String = { + let snapshot_of_path = path_root.lock().await; + self.create_path(&snapshot_of_path, 0).await + }; + + let warning_message = if warning { + "YOU HAVE 
ALREADY SEEN THIS NODE! YOU WON'T BE ABLE TO SEE IT AGAIN! SEEING IT AGAIN WON'T GIVE YOU MORE INFORMATION." + } else { + "" + }; + + let agent_prompt = format!( + "{warning_message}\nThe current explored node is:\n{data}\nThe original question is:\n{question}\nThe full exploration is:{named_path}" + ); + + let prompt_answer = self + .fast_agent + .prompt(agent_prompt) + .await + .unwrap_or("Agent error".to_string()); + + tracing::info!("Agent answer: {prompt_answer}"); + + for explore in explore_regex.captures_iter(&prompt_answer) { + if let Some(full_id) = full_id_resolver.get(&explore[1]) { + explores.push(full_id.clone()); + } + if let Some(hidden_nodes) = expansions.get(&explore[1]) { + let mut first = true; + for node in hidden_nodes { + let mut current_node_mut = node.lock().await; + current_node_mut.is_hidden = false; + if first { + explores.push(current_node_mut.id.clone()); + first = false; + } + drop(current_node_mut); + } + } + } + + if let Some(answer) = answer_regex.captures(&prompt_answer) { + partial_answers.push(answer[1].to_string()); + stack.push(current.clone()); + } + + if end_regex.captures(&prompt_answer).is_some() { + break; + } + + for potential_exploration_node in &potential_explore { + let node = potential_exploration_node.lock().await; + + let insert_node = explores.contains(&node.id); + explores.retain(|item| item != &node.id); + drop(node); + if insert_node { + // BFS + stack.insert(0, potential_exploration_node.clone()); + } + } + } + + let named_path: String = { + let snapshot_of_path = path_root.lock().await; + self.create_path(&snapshot_of_path, 0).await + }; + + let agent_prompt = format!( + "From the search that was done can you provide a final answer? 
Here's the original question: {question}\nThe partial_answers are:[{}]\nThe full path of nodes:{named_path}", + partial_answers.join("]\n[") + ); + + self.thinking_agent + .prompt(agent_prompt) + .await + .unwrap_or("Agent error".to_string()) + } + + async fn expand_entity( + &self, + base_id: String, + concepts_embeddings: Vec>, + ) -> Result<(Vec, HashMap>), Error> { + const IMPORTANCE_FACTOR: f64 = 0.9; + const MAX_SAME_RELATION: usize = 4; + const MAX_DIPLAYED_NEIGHBORS: usize = 15; + const EMBEDDING_DISTANCE_THRESHOLD: f64 = 1.5; + + let mut relation_count: HashMap = HashMap::new(); + let mut hidden_options: HashMap> = HashMap::new(); + + let mut relations = self + .extract_path_nodes_to_neighbors(base_id.clone(), 0, false) + .await + .unwrap_or_default(); + + relations.extend( + self.extract_path_nodes_to_neighbors(base_id.clone(), 0, true) + .await + .unwrap_or_default(), + ); + + let mut relations_with_scores: Vec<(f64, PathNode)> = relations + .into_iter() + .map(|path_node| { + let name_embedding = self.create_embedding(vec![&path_node.entity_name]); + let relation_embedding = self.create_embedding(vec![&path_node.relation_name]); + + // Compute all cosine distances to concept embeddings and take the best (lowest) + let min_name_distance = concepts_embeddings + .iter() + .map(|concept| self.cosine_distance(concept, &name_embedding)) + .fold(f64::INFINITY, |a, b| a.min(b)); + + // Compute all cosine distances to concept embeddings and take the best (lowest) + let min_rel_distance = concepts_embeddings + .iter() + .map(|concept| self.cosine_distance(concept, &relation_embedding)) + .fold(f64::INFINITY, |a, b| a.min(b)); + + // Many times, the same relation name is used. 
It now differentiate on the best entity name + let distance_score = if min_name_distance < min_rel_distance { + min_name_distance * IMPORTANCE_FACTOR + + min_rel_distance * (1.0 - IMPORTANCE_FACTOR) + } else { + min_name_distance * (1.0 - IMPORTANCE_FACTOR) + + min_rel_distance * IMPORTANCE_FACTOR + }; + + (distance_score, path_node) + }) + .collect(); + + relations_with_scores + .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Less)); + + let relations: Vec = relations_with_scores + .into_iter() + .filter(|(score, _)| score < &EMBEDDING_DISTANCE_THRESHOLD) + .filter_map(|(_, path_node)| { + let count = relation_count + .entry(path_node.relation_name.clone()) + .and_modify(|c| *c += 1) + .or_insert(1); + if *count <= MAX_SAME_RELATION { + Some(path_node) + } else { + hidden_options + .entry(path_node.relation_name.clone()) + .or_default() + .push(path_node); + None + } + }) + .take(MAX_DIPLAYED_NEIGHBORS) + .collect(); + + Ok((relations, hidden_options)) + } + + async fn extract_path_nodes_to_neighbors( + &self, + id: String, + depth: usize, + is_inbound: bool, + ) -> Result, StatusCode> { + const MAX_RELATIONS_CONSIDERED: usize = 100; + + let in_filter = relation::RelationFilter::default() + .to_(EntityFilter::default().id(prop_filter::value(id.clone()))); + let out_filter = relation::RelationFilter::default() + .from_(EntityFilter::default().id(prop_filter::value(id.clone()))); + + let relations = relation::find_many::>(&self.neo4j) + .filter(if is_inbound { in_filter } else { out_filter }) + .limit(MAX_RELATIONS_CONSIDERED) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let related_nodes = join_all(relations.into_iter().map(|result| async move { + let id = if is_inbound { + result.from.id + } else { + result.to.id + }; + + let entity_name = self + .get_name_of_id(id.to_string()) + .await + .unwrap_or("No name".to_string()); + + let relation_name = self + .get_name_of_id(result.relation_type.clone()) + .await + .unwrap_or("No name".to_string()); + + PathNode { + id: id.to_string(), + relation_name, + entity_name, + nodes: Vec::new(), + is_inbound, + is_explored: false, + is_hidden: false, + depth: depth + 1, + } + })) + .await + .to_vec(); + + Ok(related_nodes) + } + + async fn get_name_of_id(&self, id: String) -> Result { + let entity = entity::find_one::>(&self.neo4j, &id) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + entity + .attributes + .name + .ok_or(StatusCode::INTERNAL_SERVER_ERROR) + } + + async fn get_description_of_id(&self, id: String) -> Result { + let entity = entity::find_one::>(&self.neo4j, &id) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(entity + .attributes + .description + .unwrap_or("No description".to_string())) + } + + async fn search( + &self, + query: String, + limit: Option, + ) -> Result, StatusCode> { + let embedding = self + .embedding_model + .embed(vec![&query], None) + .expect("Failed to get embedding") + .pop() + .expect("Embedding is empty") + .into_iter() + .map(|v| v as f64) + .collect::>(); + + let limit = limit.unwrap_or(10); + + let semantic_search_triples = triple::search(&self.neo4j, embedding) + .limit(limit) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + Ok(semantic_search_triples) + } + + async fn get_ids_from_search( + &self, + search_triples: Vec, + create_relation_filter: impl Fn(SemanticSearchResult) -> RelationFilter, + ) -> Result, StatusCode> { + let mut seen_ids: HashSet = HashSet::new(); + let mut result_ids: Vec = Vec::new(); + + for semantic_search_triple in search_triples { + let filtered_for_types = relation::find_many::>(&self.neo4j) + .filter(create_relation_filter(semantic_search_triple)) + .send() + .await; + + //We only need to get the first relation since they would share the same entity id + if let Ok(stream) = filtered_for_types { + pin_mut!(stream); + if let Some(edge) = stream.try_next().await.ok().flatten() { + let id = edge.from.id; + if seen_ids.insert(id.clone()) { + result_ids.push(id); + } + } + } + } + Ok(result_ids) + } + + #[inline(always)] + fn create_embedding(&self, name: Vec<&str>) -> Vec { + self.embedding_model + .embed(name, None) + .expect("Failed to get embedding") + .pop() + .expect("No embedding found") + .into_iter() + .map(|v| v as f64) + .collect::>() + } + + fn cosine_distance(&self, a: &[f64], b: &[f64]) -> f64 { + let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let norm_a: f64 = a.iter().map(|x| x * x).sum::().sqrt(); + let norm_b: f64 = b.iter().map(|x| x * x).sum::().sqrt(); + 1.0 - (dot_product / (norm_a * norm_b)) + } +} diff --git a/ai-search/src/main.rs b/ai-search/src/main.rs new file mode 100644 index 0000000..07dd7f2 --- /dev/null +++ b/ai-search/src/main.rs @@ -0,0 +1,97 @@ +use axum::{Router, extract::Json, http::StatusCode, routing::post}; +use clap::{Args, Parser}; +use grc20_core::neo4rs; +use std::sync::Arc; +use tracing_subscriber::{ + layer::SubscriberExt, + util::SubscriberInitExt, + {self}, +}; + +use crate::{automatic_search::AutomaticSearchAgent, full_ai_search::FullAISearchAgent}; + +mod 
automatic_search; +mod full_ai_search; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let args = AppArgs::parse(); + + let neo4j = neo4rs::Graph::new( + &args.neo4j_args.neo4j_uri, + &args.neo4j_args.neo4j_user, + &args.neo4j_args.neo4j_pass, + ) + .await?; + + let full_ai_search_agent = + Arc::new(FullAISearchAgent::new(neo4j.clone(), &args.gemini_api_key)); + let automatic_search_agent = Arc::new(AutomaticSearchAgent::new( + neo4j.clone(), + &args.gemini_api_key, + )); + + let app = Router::new() + .route("/question_ai", post(handler_full_ai_search)) + .route("/question", post(handler_automatic_search)) + .with_state((full_ai_search_agent, automatic_search_agent)); + + // run our app with hyper, listening globally on port 3000 + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + axum::serve(listener, app).await.unwrap(); + Ok(()) +} + +async fn handler_full_ai_search( + axum::extract::State((search_agent, _)): axum::extract::State<( + Arc, + Arc, + )>, + Json(question): Json, +) -> Result, StatusCode> { + tracing::info!("The question asked to the knowledge graph is: {question}"); + search_agent.natural_language_question(question).await +} + +async fn handler_automatic_search( + axum::extract::State((_, search_agent)): axum::extract::State<( + Arc, + Arc, + )>, + Json(question): Json, +) -> Result, StatusCode> { + tracing::info!("The question asked to the knowledge graph is: {question}"); + search_agent.natural_language_question(question).await +} + +#[derive(Debug, Parser)] +#[command(name = "stdout", version, about, arg_required_else_help = true)] +struct AppArgs { + #[clap(flatten)] + neo4j_args: Neo4jArgs, + #[arg(long)] + gemini_api_key: String, +} + +#[derive(Debug, Args)] +struct Neo4jArgs { + /// Neo4j database 
host + #[arg(long)] + neo4j_uri: String, + + /// Neo4j database user name + #[arg(long)] + neo4j_user: String, + + /// Neo4j database user password + #[arg(long)] + neo4j_pass: String, +} diff --git a/grc20-core/src/mapping/entity/models.rs b/grc20-core/src/mapping/entity/models.rs index f1e1dfc..51dc307 100644 --- a/grc20-core/src/mapping/entity/models.rs +++ b/grc20-core/src/mapping/entity/models.rs @@ -233,7 +233,7 @@ impl Entity { } } -#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq)] pub struct SystemProperties { #[serde(rename = "82nP7aFmHJLbaPFszj2nbx")] // CREATED_AT_TIMESTAMP pub created_at: DateTime, @@ -243,6 +243,8 @@ pub struct SystemProperties { pub updated_at: DateTime, #[serde(rename = "7pXCVQDV9C7ozrXkpVg8RJ")] // UPDATED_AT_BLOCK pub updated_at_block: String, + #[serde(rename = "pagerank")] + pub page_rank: Option, } impl From for SystemProperties { @@ -252,6 +254,7 @@ impl From for SystemProperties { created_at_block: block.block_number.to_string(), updated_at: block.timestamp, updated_at_block: block.block_number.to_string(), + page_rank: None, } } } @@ -263,6 +266,7 @@ impl Default for SystemProperties { created_at_block: "0".to_string(), updated_at: Default::default(), updated_at_block: "0".to_string(), + page_rank: None, } } } diff --git a/grc20-core/src/mapping/entity/search_with_traversals.rs b/grc20-core/src/mapping/entity/search_with_traversals.rs index ade6741..1e08c76 100644 --- a/grc20-core/src/mapping/entity/search_with_traversals.rs +++ b/grc20-core/src/mapping/entity/search_with_traversals.rs @@ -86,7 +86,7 @@ impl SearchWithTraversals { fn subquery(&self) -> QueryBuilder { const QUERY: &str = r#" - CALL db.index.vector.queryNodes('vector_index', $limit * $effective_search_ratio, $vector) + CALL db.index.vector.queryNodes('vector_index', $effective_search_ratio, $vector) YIELD node AS n, score AS score WHERE score > $threshold 
MATCH (e:Entity) -[r:ATTRIBUTE]-> (n) diff --git a/grc20-core/src/mapping/entity/semantic_search.rs b/grc20-core/src/mapping/entity/semantic_search.rs index 22c181a..0aa1de1 100644 --- a/grc20-core/src/mapping/entity/semantic_search.rs +++ b/grc20-core/src/mapping/entity/semantic_search.rs @@ -77,7 +77,7 @@ impl SemanticSearchQuery { fn subquery(&self) -> QueryBuilder { const QUERY: &str = r#" - CALL db.index.vector.queryNodes('vector_index', $limit * $effective_search_ratio, $vector) + CALL db.index.vector.queryNodes('vector_index', $effective_search_ratio, $vector) YIELD node AS n, score AS score MATCH (e:Entity) -[r:ATTRIBUTE]-> (n) "#; diff --git a/grc20-core/src/mapping/mod.rs b/grc20-core/src/mapping/mod.rs index dc9f1fd..3ee43e4 100644 --- a/grc20-core/src/mapping/mod.rs +++ b/grc20-core/src/mapping/mod.rs @@ -28,7 +28,7 @@ pub use value::{Options, Value, ValueType}; use crate::{error::DatabaseError, indexer_ids}; -pub const EFFECTIVE_SEARCH_RATIO: f64 = 100000.0; +pub const EFFECTIVE_SEARCH_RATIO: f64 = 1000000.0; pub fn new_version_index(block_number: u64, idx: usize) -> String { format!("{block_number:016}:{idx:04}") diff --git a/grc20-core/src/mapping/relation/find_many.rs b/grc20-core/src/mapping/relation/find_many.rs index ea8f049..c3e94b0 100644 --- a/grc20-core/src/mapping/relation/find_many.rs +++ b/grc20-core/src/mapping/relation/find_many.rs @@ -123,8 +123,8 @@ impl FindManyQuery { .r#where(self.version.subquery("r")), ) .subquery(self.filter.subquery("r", "from", "to")) - .subquery("ORDER BY r.index") .limit(self.limit) + .subquery("ORDER BY r.index") .skip_opt(self.skip) } diff --git a/grc20-core/src/mapping/triple.rs b/grc20-core/src/mapping/triple.rs index c917c9b..0cc95d2 100644 --- a/grc20-core/src/mapping/triple.rs +++ b/grc20-core/src/mapping/triple.rs @@ -526,6 +526,7 @@ pub struct FindManyQuery { value_type: Option>, entity_id: Option>, space_id: Option>, + limit: usize, space_version: VersionFilter, } @@ -538,6 +539,7 @@ impl 
FindManyQuery { value_type: None, entity_id: None, space_id: None, + limit: 100, space_version: VersionFilter::default(), } } @@ -567,6 +569,11 @@ impl FindManyQuery { self } + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + pub fn space_version(mut self, space_version: impl Into) -> Self { self.space_version.version_mut(space_version.into()); self @@ -595,6 +602,7 @@ impl FindManyQuery { ) .r#where(self.space_version.subquery("r")), ) + .limit(self.limit) .subquery("RETURN n{.*, entity: e.id}") } } @@ -688,7 +696,7 @@ impl QueryStream for SemanticSearchQuery { { const QUERY: &str = const_format::formatcp!( r#" - CALL db.index.vector.queryNodes('vector_index', $limit * $effective_search_ratio, $vector) + CALL db.index.vector.queryNodes('vector_index', $effective_search_ratio, $vector) YIELD node AS n, score AS score ORDER BY score DESC LIMIT $limit diff --git a/mcp-server/resources/get_entity_info_description.md b/mcp-server/resources/get_entity_info_description.md index 97bf061..6786ba3 100644 --- a/mcp-server/resources/get_entity_info_description.md +++ b/mcp-server/resources/get_entity_info_description.md @@ -23,63 +23,9 @@ ToolResult> "name": "SF Mayor Lurie launching police task force to counter crime in core downtown areas", "relation_id": "8ESicJHiNJ28VGL5u34A5q", "relation_type": "Related spaces" - }, - { - "id": "6wAoNdGVbweKi2JRPZP4bX", - "name": "San Francisco Independent Film Festival", - "relation_id": "TH5Tu5Y5nacvREvAQRvcR2", - "relation_type": "Related spaces" - }, - { - "id": "8VCHYDURDStwuTCUBjWLQa", - "name": "Product Engineer at Geo", - "relation_id": "KPTqdNpCusxfM37KbKPX8w", - "relation_type": "Related spaces" - }, - { - "id": "NcQ3h9jeJSavVd8iFsUxvD", - "name": "Senior Civil Engineer @ Golden Gate Bridge, Highway & Transportation District", - "relation_id": "AqpNtJ3XxaY4fqRCyoXbdt", - "relation_type": "Cities" - }, - { - "id": "4ojV4dS1pV2tRnzXTpcMKJ", - "name": "Senior Plan Check Engineer (FT - Hybrid) @ CSG 
Consultants, Inc.", - "relation_id": "3AX4j43nywT5eBRV3s6AXi", - "relation_type": "Cities" - }, - { - "id": "QoakYWCuv85FVuYdSmonxr", - "name": "Senior Civil Engineer - Land Development (FT - Hybrid) @ CSG Consultants, Inc.", - "relation_id": "8GEF1i3LK4Z56THjE8dVku", - "relation_type": "Cities" - }, - { - "id": "JuV7jLoypebzLhkma6oZoU", - "name": "Lead Django Backend Engineer @ Textme Inc", - "relation_id": "46aBsQyBq15DimJ2i1DX4a", - "relation_type": "Cities" - }, - { - "id": "RTmcYhLVmmfgUn9L3D1J3y", - "name": "Chief Engineer @ Wyndham Hotels & Resorts", - "relation_id": "8uYxjzkkdjskDQAeTQomvc", - "relation_type": "Cities" } ], "outbound_relations": [ - { - "id": "CUoEazCD7EmzXPTFFY8gGY", - "name": "No name", - "relation_id": "5WeSkkE1XXvGJGmXj9VUQ8", - "relation_type": "Cover" - }, - { - "id": "7gzF671tq5JTZ13naG4tnr", - "name": "Space", - "relation_id": "WUZCXE1UGRtxdNQpGug8Tf", - "relation_type": "Types" - }, { "id": "D6Wy4bdtdoUrG3PDZceHr", "name": "City", @@ -91,30 +37,6 @@ ToolResult> "name": "Upcoming events", "relation_id": "V1ikGW9riu7dAP8rMgZq3u", "relation_type": "Blocks" - }, - { - "id": "T6iKbwZ17iv4dRdR9Qw7qV", - "name": "Trending restaurants", - "relation_id": "CvGXCmGXE7ofsgZeWad28p", - "relation_type": "Blocks" - }, - { - "id": "X18WRE36mjwQ7gu3LKaLJS", - "name": "Neighborhoods", - "relation_id": "Uxpsee9LoTgJqMFfAQyJP6", - "relation_type": "Blocks" - }, - { - "id": "HeC2pygci2tnvjTt5aEnBV", - "name": "Top goals", - "relation_id": "5WMTAzCnZH9Bsevou9GQ3K", - "relation_type": "Blocks" - }, - { - "id": "5YtYFsnWq1jupvh5AjM2ni", - "name": "Culture", - "relation_id": "5TmxfepRr1THMRkGWenj5G", - "relation_type": "Tabs" } ] } diff --git a/mcp-server/resources/instructions.md b/mcp-server/resources/instructions.md index 9aeec45..e70d5be 100644 --- a/mcp-server/resources/instructions.md +++ b/mcp-server/resources/instructions.md @@ -5,181 +5,42 @@ You should use it for every request to get the informations for your answers sin The tools defined in the 
MCP server are made to be used in combination with each other. All except the most trivial requests will require the use of multiple tools. Here is an example: -User> Can you give me information about San Francisco? +User> Can you give me information about other restaurants near Saison? -ToolCall> search_entity({"query": "San Francisco"}) -ToolResult> -``` -{ - "entities": [ - { - "description": "A vibrant city known for its iconic Golden Gate Bridge, steep rolling hills, historic cable cars, and a rich cultural tapestry including diverse neighborhoods like the Castro and the Mission District.", - "id": "3qayfdjYyPv1dAYf8gPL5r", - "name": "San Francisco" - }, - { - "description": null, - "id": "W5ZEpuy3Tij1XSXtJLruQ5", - "name": "SF Bay Area" - }, - { - "description": null, - "id": "RHoJT3hNVaw7m5fLLtZ8WQ", - "name": "California" - }, - { - "description": null, - "id": "Sh1qtjr4i92ZD6YGPeu5a2", - "name": "Abundant housing in San Francisco" - }, - { - "description": null, - "id": "UqLf9fTVKHkDs3LzP9zHpH", - "name": "Public safety in San Francisco" - }, - { - "description": null, - "id": "BeyiZ6oLqLMaSXiG41Yxtf", - "name": "City" - }, - { - "description": null, - "id": "D6Wy4bdtdoUrG3PDZceHr", - "name": "City" - }, - { - "description": null, - "id": "JWVrgUXmjS75PqNX2hry5q", - "name": "Clean streets in San Francisco" - }, - { - "description": null, - "id": "DcA2c7ooFTgEdtaRcaj7Z1", - "name": "Revitalizing downtown San Francisco" - }, - { - "description": null, - "id": "KWBLj9czHBBmYUT98rnxVM", - "name": "Location" - } - ] -} -``` -Let's get more info about San Francisco (id: 3qayfdjYyPv1dAYf8gPL5r) +Let's get precise information about Atelier Crenn. 
+Atelier Crenn id: "WFYTR1pxsZNjk8p6Z2CCg4" -ToolCall> get_entity_info("3qayfdjYyPv1dAYf8gPL5r") +ToolCall> get_entity_info("WFYTR1pxsZNjk8p6Z2CCg4") ToolResult> ``` { "all_attributes": [ - { - "attribute_name": "Description", - "attribute_value": "A vibrant city known for its iconic Golden Gate Bridge, steep rolling hills, historic cable cars, and a rich cultural tapestry including diverse neighborhoods like the Castro and the Mission District." - }, { "attribute_name": "Name", - "attribute_value": "San Francisco" + "attribute_value": "Atelier Crenn" } ], - "id": "3qayfdjYyPv1dAYf8gPL5r", + "id": "WFYTR1pxsZNjk8p6Z2CCg4", "inbound_relations": [ { - "id": "NAMA1uDMzBQTvPYV9N92BV", - "name": "SF Mayor Lurie launching police task force to counter crime in core downtown areas", - "relation_id": "8ESicJHiNJ28VGL5u34A5q", - "relation_type": "Related spaces" - }, - { - "id": "6wAoNdGVbweKi2JRPZP4bX", - "name": "San Francisco Independent Film Festival", - "relation_id": "TH5Tu5Y5nacvREvAQRvcR2", - "relation_type": "Related spaces" - }, - { - "id": "8VCHYDURDStwuTCUBjWLQa", - "name": "Product Engineer at Geo", - "relation_id": "KPTqdNpCusxfM37KbKPX8w", - "relation_type": "Related spaces" - }, - { - "id": "NcQ3h9jeJSavVd8iFsUxvD", - "name": "Senior Civil Engineer @ Golden Gate Bridge, Highway & Transportation District", - "relation_id": "AqpNtJ3XxaY4fqRCyoXbdt", - "relation_type": "Cities" - }, - { - "id": "4ojV4dS1pV2tRnzXTpcMKJ", - "name": "Senior Plan Check Engineer (FT - Hybrid) @ CSG Consultants, Inc.", - "relation_id": "3AX4j43nywT5eBRV3s6AXi", - "relation_type": "Cities" - }, - { - "id": "QoakYWCuv85FVuYdSmonxr", - "name": "Senior Civil Engineer - Land Development (FT - Hybrid) @ CSG Consultants, Inc.", - "relation_id": "8GEF1i3LK4Z56THjE8dVku", - "relation_type": "Cities" - }, - { - "id": "JuV7jLoypebzLhkma6oZoU", - "name": "Lead Django Backend Engineer @ Textme Inc", - "relation_id": "46aBsQyBq15DimJ2i1DX4a", - "relation_type": "Cities" - }, - { - "id": 
"RTmcYhLVmmfgUn9L3D1J3y", - "name": "Chief Engineer @ Wyndham Hotels & Resorts", - "relation_id": "8uYxjzkkdjskDQAeTQomvc", - "relation_type": "Cities" + "id": "T6iKbwZ17iv4dRdR9Qw7qV", + "name": "Trending restaurants", + "relation_id": "Mwrn46KavwfWgNrFaWcB9j", + "relation_type": "Collection item" } ], "outbound_relations": [ { - "id": "CUoEazCD7EmzXPTFFY8gGY", + "id": "AxW1SQEvzvuKkPV6T19VDL", "name": "No name", - "relation_id": "5WeSkkE1XXvGJGmXj9VUQ8", + "relation_id": "7YHk6qYkNDaAtNb8GwmysF", "relation_type": "Cover" }, { - "id": "7gzF671tq5JTZ13naG4tnr", - "name": "Space", - "relation_id": "WUZCXE1UGRtxdNQpGug8Tf", + "id": "A9QizqoXSqjfPUBjLoPJa2", + "name": "Restaurant", + "relation_id": "Jfmby78N4BCseZinBmdVov", "relation_type": "Types" - }, - { - "id": "D6Wy4bdtdoUrG3PDZceHr", - "name": "City", - "relation_id": "ARMj8fjJtdCwbtZa1f3jwe", - "relation_type": "Types" - }, - { - "id": "AhidiWYnQ8fAbHqfzdU74k", - "name": "Upcoming events", - "relation_id": "V1ikGW9riu7dAP8rMgZq3u", - "relation_type": "Blocks" - }, - { - "id": "T6iKbwZ17iv4dRdR9Qw7qV", - "name": "Trending restaurants", - "relation_id": "CvGXCmGXE7ofsgZeWad28p", - "relation_type": "Blocks" - }, - { - "id": "X18WRE36mjwQ7gu3LKaLJS", - "name": "Neighborhoods", - "relation_id": "Uxpsee9LoTgJqMFfAQyJP6", - "relation_type": "Blocks" - }, - { - "id": "HeC2pygci2tnvjTt5aEnBV", - "name": "Top goals", - "relation_id": "5WMTAzCnZH9Bsevou9GQ3K", - "relation_type": "Blocks" - }, - { - "id": "5YtYFsnWq1jupvh5AjM2ni", - "name": "Culture", - "relation_id": "5TmxfepRr1THMRkGWenj5G", - "relation_type": "Tabs" } ] } diff --git a/mcp-server/resources/name_search_entity_description.md b/mcp-server/resources/name_search_entity_description.md index 8377684..2d13b3b 100644 --- a/mcp-server/resources/name_search_entity_description.md +++ b/mcp-server/resources/name_search_entity_description.md @@ -1,43 +1,5 @@ This request allows you to get Entities from a name/description search and traversal from that query by using 
relation name. -Example Query: Find employees that works at The Graph. - -ToolCall> -``` -name_search_entity( - { - "query": "The Graph", - "traversal_filter": { - "relation_type_id": "Works at", - "direction": "From" - } - } -) -``` - -ToolResult> -``` -{ - "entities": [ - { - "description": "Founder & CEO of Geo. Cofounder of The Graph, Edge & Node, House of Web3. Building a vibrant decentralized future.", - "id": "9HsfMWYHr9suYdMrtssqiX", - "name": "Yaniv Tal" - }, - { - "description": "Developer Relations Engineer", - "id": "22MGz47c9WHtRiHuSEPkcG", - "name": "Kevin Jones" - }, - { - "description": "Description will go here", - "id": "JYTfEcdmdjiNzBg469gE83", - "name": "Pedro Diogo" - } - ] -} -``` - Example Query: Find all the articles written by employees that works at The Graph. ToolCall> @@ -56,7 +18,6 @@ name_search_entity( } ) ``` - ToolResult> ``` { diff --git a/mcp-server/resources/search_entity_description.md b/mcp-server/resources/search_entity_description.md index 6ba6e33..75beb58 100644 --- a/mcp-server/resources/search_entity_description.md +++ b/mcp-server/resources/search_entity_description.md @@ -1,73 +1,6 @@ This request allows you to get the Entities from a name/description search and traversal from that query if needed. - -Example Query: Can you give me information about San Francisco? 
- -ToolCall> -``` -search_entity({ -"query": "San Francisco" -}) -``` -Tool Result> -``` -{ - "entities": [ - { - "description": "A vibrant city known for its iconic Golden Gate Bridge, steep rolling hills, historic cable cars, and a rich cultural tapestry including diverse neighborhoods like the Castro and the Mission District.", - "id": "3qayfdjYyPv1dAYf8gPL5r", - "name": "San Francisco" - }, - { - "description": null, - "id": "W5ZEpuy3Tij1XSXtJLruQ5", - "name": "SF Bay Area" - }, - { - "description": null, - "id": "RHoJT3hNVaw7m5fLLtZ8WQ", - "name": "California" - }, - { - "description": null, - "id": "Sh1qtjr4i92ZD6YGPeu5a2", - "name": "Abundant housing in San Francisco" - }, - { - "description": null, - "id": "UqLf9fTVKHkDs3LzP9zHpH", - "name": "Public safety in San Francisco" - }, - { - "description": null, - "id": "BeyiZ6oLqLMaSXiG41Yxtf", - "name": "City" - }, - { - "description": null, - "id": "D6Wy4bdtdoUrG3PDZceHr", - "name": "City" - }, - { - "description": null, - "id": "JWVrgUXmjS75PqNX2hry5q", - "name": "Clean streets in San Francisco" - }, - { - "description": null, - "id": "DcA2c7ooFTgEdtaRcaj7Z1", - "name": "Revitalizing downtown San Francisco" - }, - { - "description": null, - "id": "KWBLj9czHBBmYUT98rnxVM", - "name": "Location" - } - ] -} -``` - -Another Query: Give me the employees that work at The Graph? +Example Query: Give me the employees that work at The Graph? 
Work_at id: U1uCAzXsRSTP4vFwo1JwJG ToolCall> diff --git a/mcp-server/resources/search_relation_type_description.md b/mcp-server/resources/search_relation_type_description.md index 02134b4..a804b21 100644 --- a/mcp-server/resources/search_relation_type_description.md +++ b/mcp-server/resources/search_relation_type_description.md @@ -15,28 +15,6 @@ ToolResult> "attribute_value": "0", "entity_id": "U1uCAzXsRSTP4vFwo1JwJG" } - ], - [ - { - "attribute_name": "Name", - "attribute_value": "Worked at", - "entity_id": "8fvqALeBDwEExJsDeTcvnV" - }, - { - "attribute_name": "Is type property", - "attribute_value": "0", - "entity_id": "8fvqALeBDwEExJsDeTcvnV" - }, - { - "attribute_name": "Name", - "attribute_value": "Worked at", - "entity_id": "8fvqALeBDwEExJsDeTcvnV" - }, - { - "attribute_name": "Description", - "attribute_value": "A project that someone worked at in the past. Details about the role can be added as properties on the relation.", - "entity_id": "8fvqALeBDwEExJsDeTcvnV" - } ] ] ``` diff --git a/mcp-server/resources/search_space_description.md b/mcp-server/resources/search_space_description.md index 3224922..8f442ba 100644 --- a/mcp-server/resources/search_space_description.md +++ b/mcp-server/resources/search_space_description.md @@ -19,20 +19,6 @@ ToolResult> "attribute_value": "San Francisco", "entity_id": "3qayfdjYyPv1dAYf8gPL5r" } - ], - [ - { - "attribute_name": "Name", - "attribute_value": "SF Bay Area", - "entity_id": "W5ZEpuy3Tij1XSXtJLruQ5" - } - ], - [ - { - "attribute_name": "Name", - "attribute_value": "California", - "entity_id": "RHoJT3hNVaw7m5fLLtZ8WQ" - } ] ] ``` diff --git a/mcp-server/resources/search_type_description.md b/mcp-server/resources/search_type_description.md index 478367f..53cf963 100644 --- a/mcp-server/resources/search_type_description.md +++ b/mcp-server/resources/search_type_description.md @@ -15,18 +15,6 @@ ToolResult> "attribute_value": "University", "entity_id": "L8iozarUyS8bkcUiS6kPqV" } - ], - [ - { - "attribute_name": 
"Description", - "attribute_value": "An educational institution where students acquire knowledge, skills, and credentials through structured learning programs.", - "entity_id": "M89C7wwdJVaCW9rAVQpJbY" - }, - { - "attribute_name": "Name", - "attribute_value": "School", - "entity_id": "M89C7wwdJVaCW9rAVQpJbY" - } ] ] ``` diff --git a/mcp-server/src/main.rs b/mcp-server/src/main.rs index 3147abe..8231e58 100644 --- a/mcp-server/src/main.rs +++ b/mcp-server/src/main.rs @@ -113,87 +113,7 @@ impl KnowledgeGraph { RawResource::new(uri, name.to_string()).no_annotation() } - async fn search( - &self, - query: String, - limit: Option, - ) -> Result, McpError> { - let embedding = self - .embedding_model - .embed(vec![&query], None) - .expect("Failed to get embedding") - .pop() - .expect("Embedding is empty") - .into_iter() - .map(|v| v as f64) - .collect::>(); - - let limit = limit.unwrap_or(10); - - let semantic_search_triples = triple::search(&self.neo4j, embedding) - .limit(limit) - .send() - .await - .map_err(|e| { - McpError::internal_error( - "search_types_failed", - Some(json!({ "error": e.to_string() })), - ) - })? 
- .try_collect::>() - .await - .map_err(|e| { - McpError::internal_error( - "search_types_failed", - Some(json!({ "error": e.to_string() })), - ) - })?; - Ok(semantic_search_triples) - } - - async fn get_ids_from_search( - &self, - search_triples: Vec, - create_relation_filter: impl Fn(SemanticSearchResult) -> RelationFilter, - ) -> Result, McpError> { - let mut seen_ids: HashSet = HashSet::new(); - let mut result_ids: Vec = Vec::new(); - - for semantic_search_triple in search_triples { - let filtered_for_types = relation::find_many::>(&self.neo4j) - .filter(create_relation_filter(semantic_search_triple)) - .send() - .await; - - //We only need to get the first relation since they would share the same entity id - if let Ok(stream) = filtered_for_types { - pin_mut!(stream); - if let Some(edge) = stream.try_next().await.ok().flatten() { - let id = edge.from.id; - if seen_ids.insert(id.clone()) { - result_ids.push(id); - } - } - } - } - Ok(result_ids) - } - - async fn format_triples_detailled( - &self, - triples: Result, ErrorData>, - ) -> Vec { - if let Ok(triples) = triples { - join_all(triples.into_iter().map(|triple| async move {json!({ - "entity_id": triple.entity, - "attribute_name": self.get_name_of_id(triple.attribute).await.unwrap_or("No attribute name".to_string()), - "attribute_value": String::try_from(triple.value).unwrap_or("No value".to_string()) - })})).await.to_vec() - } else { - Vec::new() - } - } - + // Tools that are available in the mcp server #[tool(description = include_str!("../resources/search_type_description.md"))] async fn search_types( &self, @@ -454,15 +374,7 @@ impl KnowledgeGraph { search_traversal_filter.query ); - let embedding = self - .embedding_model - .embed(vec![&search_traversal_filter.query], None) - .expect("Failed to get embedding") - .pop() - .expect("Embedding is empty") - .into_iter() - .map(|v| v as f64) - .collect::>(); + let embedding = self.create_embedding(vec![&search_traversal_filter.query]); let traversal_filters: 
Vec<_> = search_traversal_filter .traversal_filter @@ -535,15 +447,7 @@ impl KnowledgeGraph { ) -> Result { tracing::info!("SearchTraversalFilter query: {:?}", search_traversal_filter); - let embedding = self - .embedding_model - .embed(vec![&search_traversal_filter.query], None) - .expect("Failed to get embedding") - .pop() - .expect("Embedding is empty") - .into_iter() - .map(|v| v as f64) - .collect::>(); + let embedding = self.create_embedding(vec![&search_traversal_filter.query]); let traversal_filters: Vec> = match search_traversal_filter.traversal_filter { @@ -725,7 +629,7 @@ impl KnowledgeGraph { .into_iter() .map(|result| async move { json!({ - "relation_id": result.id, + "relation_id": result.relation_type, "relation_type": self.get_name_of_id(result.relation_type).await.unwrap_or("No relation type".to_string()), "id": if is_inbound {result.from.id.clone()} else {result.to.id.clone()}, "name": self.get_name_of_id(if is_inbound {result.from.id.clone()} else {result.to.id.clone()}).await.unwrap_or("No name".to_string()), @@ -797,6 +701,91 @@ impl KnowledgeGraph { )) } + async fn search( + &self, + query: String, + limit: Option, + ) -> Result, McpError> { + let embedding = self.create_embedding(vec![&query]); + + let limit = limit.unwrap_or(10); + + let semantic_search_triples = triple::search(&self.neo4j, embedding) + .limit(limit) + .send() + .await + .map_err(|e| { + McpError::internal_error( + "search_types_failed", + Some(json!({ "error": e.to_string() })), + ) + })? 
+ .try_collect::>() + .await + .map_err(|e| { + McpError::internal_error( + "search_types_failed", + Some(json!({ "error": e.to_string() })), + ) + })?; + Ok(semantic_search_triples) + } + + async fn get_ids_from_search( + &self, + search_triples: Vec, + create_relation_filter: impl Fn(SemanticSearchResult) -> RelationFilter, + ) -> Result, McpError> { + let mut seen_ids: HashSet = HashSet::new(); + let mut result_ids: Vec = Vec::new(); + + for semantic_search_triple in search_triples { + let filtered_for_types = relation::find_many::>(&self.neo4j) + .filter(create_relation_filter(semantic_search_triple)) + .send() + .await; + + //We only need to get the first relation since they would share the same entity id + if let Ok(stream) = filtered_for_types { + pin_mut!(stream); + if let Some(edge) = stream.try_next().await.ok().flatten() { + let id = edge.from.id; + if seen_ids.insert(id.clone()) { + result_ids.push(id); + } + } + } + } + Ok(result_ids) + } + + async fn format_triples_detailled( + &self, + triples: Result, ErrorData>, + ) -> Vec { + if let Ok(triples) = triples { + join_all(triples.into_iter().map(|triple| async move {json!({ + "entity_id": triple.entity, + "attribute_name": self.get_name_of_id(triple.attribute).await.unwrap_or("No attribute name".to_string()), + "attribute_value": String::try_from(triple.value).unwrap_or("No value".to_string()) + })})).await.to_vec() + } else { + Vec::new() + } + } + + #[inline(always)] + fn create_embedding(&self, name: Vec<&str>) -> Vec { + self.embedding_model + .embed(name, None) + .expect("Failed to get embedding") + .pop() + .expect("No embedding found") + .into_iter() + .map(|v| v as f64) + .collect::>() + } + async fn get_name_of_id(&self, id: String) -> Result { let entity = entity::find_one::>(&self.neo4j, &id) .send() @@ -807,7 +796,9 @@ impl KnowledgeGraph { .ok_or_else(|| { McpError::internal_error("entity_name_not_found", Some(json!({ "id": id }))) })?; - Ok(entity.attributes.name.unwrap_or("No 
name".to_string())) + entity.attributes.name.ok_or_else(|| { + McpError::internal_error("entity_name_not_found", Some(json!({ "id": id }))) + }) } } @@ -838,53 +829,6 @@ impl ServerHandler for KnowledgeGraph { } Ok(self.get_info()) } - - //TODO: make prompt examples to use on data - async fn list_prompts( - &self, - _request: Option, - _: RequestContext, - ) -> Result { - Ok(ListPromptsResult { - next_cursor: None, - prompts: vec![Prompt::new( - "example_prompt", - Some("This is an example prompt that takes one required argument, message"), - Some(vec![PromptArgument { - name: "message".to_string(), - description: Some("A message to put in the prompt".to_string()), - required: Some(true), - }]), - )], - }) - } - - async fn get_prompt( - &self, - GetPromptRequestParam { name, arguments }: GetPromptRequestParam, - _: RequestContext, - ) -> Result { - match name.as_str() { - "example_prompt" => { - let message = arguments - .and_then(|json| json.get("message")?.as_str().map(|s| s.to_string())) - .ok_or_else(|| { - McpError::invalid_params("No message provided to example_prompt", None) - })?; - - let prompt = - format!("This is an example prompt with your message here: '{message}'"); - Ok(GetPromptResult { - description: None, - messages: vec![PromptMessage { - role: PromptMessageRole::User, - content: PromptMessageContent::text(prompt), - }], - }) - } - _ => Err(McpError::invalid_params("prompt not found", None)), - } - } } #[derive(Debug, Parser)] From dc4cdeeef0efda6929caa50b80b85fa0e6877ebf Mon Sep 17 00:00:00 2001 From: tourtourigny Date: Wed, 20 Aug 2025 16:54:18 -0400 Subject: [PATCH 2/3] fix update README --- ai-search/README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ai-search/README.md b/ai-search/README.md index 270f192..458a973 100644 --- a/ai-search/README.md +++ b/ai-search/README.md @@ -1,12 +1,15 @@ # README ai-search ## API -ai-search is a REST API that can be queried at the adress 127.0.0.0:3000. 
-The only available route is /question that takes a JSON containing the question for the knowledge graph.
+ai-search is a REST API that can be queried at the address 0.0.0.0:3000.
+There are 2 available routes to query the knowledge graph using natural language.
+
+The first route is /question that takes a question in the body as a String and gives the answer back quicker. It uses fewer Gemini calls and is faster.
+The second route is /question_ai that takes a question in the body as a String and gives an answer back. It relies more on Gemini calls and is slower, but has more precision.
 
 ## Start command
 ```
-cargo run --bin ai-search -- --neo4j-uri neo4j://localhost:7687 --neo4j-user neo4j --neo4j-pass neo4j --gemini-api-key 
+cargo run --bin ai-search -- --neo4j-uri neo4j://localhost:7687 --neo4j-user neo4j --neo4j-pass neo4j --gemini-api-key 
```
\ No newline at end of file

From e50c45d27aba4730e11726017beb3468ede1dc23 Mon Sep 17 00:00:00 2001
From: tourtourigny
Date: Wed, 20 Aug 2025 17:19:09 -0400
Subject: [PATCH 3/3] fix clippy error

---
 ai-search/src/automatic_search.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/ai-search/src/automatic_search.rs b/ai-search/src/automatic_search.rs
index e9a47ed..9eebb86 100644
--- a/ai-search/src/automatic_search.rs
+++ b/ai-search/src/automatic_search.rs
@@ -475,10 +475,10 @@ impl AutomaticSearchAgent {
                 tracing::error!("Error: {e}");
                 StatusCode::INTERNAL_SERVER_ERROR
             })?;
-        if let Some(name_entity) = triple_result.first() {
-            if let Some(embedding) = &name_entity.embedding {
-                return Ok((embedding.clone(), name_entity.value.value.clone()));
-            }
+        if let Some(name_entity) = triple_result.first()
+            && let Some(embedding) = &name_entity.embedding
+        {
+            return Ok((embedding.clone(), name_entity.value.value.clone()));
         }
         Err(StatusCode::INTERNAL_SERVER_ERROR)
     }