diff --git a/Cargo.lock b/Cargo.lock index 1bafabf..81f2b70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,27 @@ dependencies = [ "memchr", ] +[[package]] +name = "ai-search" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum 0.8.4", + "clap", + "fastembed", + "futures", + "grc20-core", + "grc20-sdk", + "rand 0.8.5", + "regex", + "rig-core", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "aligned-vec" version = "0.5.0" @@ -151,6 +172,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "as-any" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0f477b951e452a0b6b4a10b53ccd569042d1d01729b519e02074a9c0958a063" + [[package]] name = "ascii-canvas" version = "3.0.0" @@ -669,7 +696,7 @@ dependencies = [ "serde_json", "serde_repr", "serde_urlencoded", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tokio-util", "tower-service", @@ -720,9 +747,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" dependencies = [ "serde", ] @@ -1632,6 +1659,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "gloo-timers" version = "0.3.0" @@ -1662,7 +1695,7 @@ dependencies = [ "serde_json", "serde_with", 
"testcontainers", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tracing", "uuid", @@ -1798,10 +1831,10 @@ dependencies = [ "log", "native-tls", "rand 0.8.5", - "reqwest 0.12.12", + "reqwest 0.12.22", "serde", "serde_json", - "thiserror 2.0.11", + "thiserror 2.0.12", "ureq", "windows-sys 0.59.0", ] @@ -2032,21 +2065,28 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" dependencies = [ + "base64 0.22.1", "bytes", "futures-channel", + "futures-core", "futures-util", "http 1.2.0", "http-body 1.0.1", "hyper 1.6.0", + "ipnet", + "libc", + "percent-encoding", "pin-project-lite", "socket2", + "system-configuration 0.6.1", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] @@ -2331,8 +2371,8 @@ name = "ipfs" version = "0.1.0" dependencies = [ "prost", - "reqwest 0.12.12", - "thiserror 2.0.11", + "reqwest 0.12.22", + "thiserror 2.0.12", "tokio", "tracing", ] @@ -2343,6 +2383,16 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -2568,9 +2618,9 @@ checksum = "db13adb97ab515a3691f56e4dbab09283d0b86cb45abd991d8634a9d6f501760" [[package]] name = "libc" -version = "0.2.169" +version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = 
"1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] name = "libfuzzer-sys" @@ -2769,6 +2819,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -3087,6 +3147,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01" +dependencies = [ + "num-traits", +] + [[package]] name = "ort" version = "2.0.0-rc.9" @@ -3762,9 +3831,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.12" +version = "0.12.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" +checksum = "cbc931937e6ca3a06e3b6c0aa7841849b160a90351d6ab467a8b9b9959767531" dependencies = [ "base64 0.22.1", "bytes", @@ -3779,31 +3848,29 @@ dependencies = [ "hyper-rustls", "hyper-tls 0.6.0", "hyper-util", - "ipnet", "js-sys", "log", "mime", + "mime_guess", "native-tls", - "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper 1.0.2", - "system-configuration 0.6.1", "tokio", "tokio-native-tls", "tokio-util", "tower 0.5.2", + "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", "web-sys", - "windows-registry", ] [[package]] @@ 
-3818,6 +3885,29 @@ version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" +[[package]] +name = "rig-core" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e57d16b6bc8ed1a4d7cea51782f1c6320e0a98b01a8de719540251a81b027ffb" +dependencies = [ + "as-any", + "async-stream", + "base64 0.22.1", + "bytes", + "futures", + "glob", + "mime_guess", + "ordered-float", + "reqwest 0.12.22", + "schemars", + "serde", + "serde_json", + "thiserror 2.0.12", + "tracing", + "url", +] + [[package]] name = "ring" version = "0.17.8" @@ -3849,7 +3939,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tokio-stream", "tokio-util", @@ -4294,7 +4384,7 @@ dependencies = [ "prometheus", "prost", "prost-types", - "reqwest 0.12.12", + "reqwest 0.12.22", "serde", "serde_json", "serde_path_to_error", @@ -4302,7 +4392,7 @@ dependencies = [ "substreams-utils", "test-log", "testcontainers", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tracing", "tracing-appender", @@ -4355,9 +4445,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", "windows-sys 0.52.0", @@ -4663,7 +4753,7 @@ dependencies = [ "serde", "serde_json", "serde_with", - "thiserror 2.0.11", + "thiserror 2.0.12", "tokio", "tokio-stream", "tokio-tar", @@ -4682,11 +4772,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = 
"567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -4702,9 +4792,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", @@ -4807,7 +4897,7 @@ dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror 2.0.11", + "thiserror 2.0.12", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", @@ -5019,14 +5109,18 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.2" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" dependencies = [ "bitflags 2.8.0", "bytes", + "futures-util", "http 1.2.0", + "http-body 1.0.1", + "iri-string", "pin-project-lite", + "tower 0.5.2", "tower-layer", "tower-service", ] @@ -5160,6 +5254,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.16" @@ -5525,32 +5625,31 @@ checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" [[package]] name = "windows-registry" -version = "0.2.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = "b3bab093bdd303a1240bb99b8aba8ea8a69ee19d34c9e2ef9594e708a4878820" dependencies = [ + "windows-link", "windows-result", "windows-strings", - "windows-targets 0.52.6", ] [[package]] name = "windows-result" -version = "0.2.0" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-targets 0.52.6", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-result", - "windows-targets 0.52.6", + "windows-link", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 29e4a9a..277fcaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,5 +8,5 @@ members = [ "web3-utils", "grc20-core", "grc20-macros", - "grc20-sdk", "mcp-server", + "grc20-sdk", "mcp-server", "ai-search", ] diff --git a/ai-search/Cargo.toml b/ai-search/Cargo.toml new file mode 100644 index 0000000..76f7ddf --- /dev/null +++ b/ai-search/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "ai-search" +version = "0.1.0" +edition = "2024" + +[dependencies] +axum= "0.8.4" +anyhow = "1.0.89" +clap = { version = "4.5.39", features = ["derive", "env"] } +fastembed = "4.8.0" +futures = "0.3.31" +tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } + +grc20-core = { version = "0.1.0", path = "../grc20-core" } +grc20-sdk = { version = "0.1.0", path = "../grc20-sdk" } + +rand = "0.8" +regex = "1.10" +rig-core = "0.17.1" +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" +tracing = "0.1.41" +tracing-subscriber 
= { version = "0.3.19", features = ["env-filter"] } diff --git a/ai-search/README.md b/ai-search/README.md new file mode 100644 index 0000000..458a973 --- /dev/null +++ b/ai-search/README.md @@ -0,0 +1,15 @@ +# README ai-search + +## API +ai-search is a REST API that can be queried at the address 0.0.0.0:3000. +There are 2 available routes to query the knowledge graph using natural language. + +The first route is /question that takes a question in the body as a String and gives the answer back quicker. It uses fewer Gemini calls and is faster. +The second route is /question_ai that takes a question in the body as a String and gives an answer back. It relies more on Gemini calls and is slower, but has more precision. + + +## Start command +``` +cargo run --bin ai-search -- --neo4j-uri neo4j://localhost:7687 --neo4j-user neo4j --neo4j-pass neo4j --gemini-api-key + +``` \ No newline at end of file diff --git a/ai-search/ressources/traversal_prompt.md b/ai-search/ressources/traversal_prompt.md new file mode 100644 index 0000000..8d372eb --- /dev/null +++ b/ai-search/ressources/traversal_prompt.md @@ -0,0 +1,11 @@ +You are a Knowledge Graph traversal agent. Your goal is to answer the user’s question using the fewest steps possible. +You may: + +ID to discover an unexplored neighbor’s details. + + to give a partial answer if you have some important information. + + if no path can yield the answer. + +Never re-explore a node. Always choose the neighbor(s) most likely to contain the answer or get some more relevant information. Try to expand until you have an answer. You can explore multiple nodes and expansions at the same time with multiple explore.
+ diff --git a/ai-search/src/automatic_search.rs b/ai-search/src/automatic_search.rs new file mode 100644 index 0000000..9eebb86 --- /dev/null +++ b/ai-search/src/automatic_search.rs @@ -0,0 +1,507 @@ +use axum::{extract::Json, http::StatusCode}; +use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; +use futures::{TryStreamExt, future::join_all, pin_mut}; +use grc20_core::{ + entity::{self, Entity, EntityFilter, EntityNode}, + mapping::{ + Query, QueryStream, RelationEdge, prop_filter, + triple::{self, SemanticSearchResult}, + }, + neo4rs, + relation::{self, RelationFilter}, + system_ids, +}; +use grc20_sdk::models::BaseEntity; +use regex::Regex; +use rig::{agent::Agent, completion::Prompt, providers::gemini::completion::CompletionModel}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +const EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::AllMiniLML6V2; + +pub struct AutomaticSearchAgent { + neo4j: neo4rs::Graph, + pub embedding_model: Arc, + pub fast_agent: Agent, + pub thinking_agent: Agent, +} + +impl AutomaticSearchAgent { + pub fn new(neo4j: neo4rs::Graph, gemini_api_key: &str) -> Self { + let preamble = "You are a knowledge graph helper to answer natural language questions."; + let gemini_client = rig::providers::gemini::Client::new(gemini_api_key); + + Self { + neo4j, + embedding_model: Arc::new( + TextEmbedding::try_new( + InitOptions::new(EMBEDDING_MODEL).with_show_download_progress(true), + ) + .expect("Failed to initialize embedding model"), + ), + fast_agent: gemini_client + .agent("gemini-2.5-flash-lite") + .temperature(0.4) + .preamble(preamble) + .build(), + thinking_agent: gemini_client + .agent("gemini-2.5-flash") + .temperature(0.4) + .preamble(preamble) + .build(), + } + } + + pub async fn natural_language_question( + &self, + question: String, + ) -> Result, StatusCode> { + let string_extraction_regex = Regex::new(r#""([^"]+)""#).unwrap(); + + let create_relation_filter = |search_result: SemanticSearchResult| { + 
RelationFilter::default() + .from_(EntityFilter::default().id(prop_filter::value(search_result.triple.entity))) + }; + + let main_entity = self.fast_agent.prompt(format!("Can you extract the main 1-3 Entities from which you can do a search from the original question. THE ANSWER SHOULD BE A SINGLE WORD OR A SINGLE CONCEPT IN QUOTATION MARKS. IF THERE IS MORE THAN ONE IMPORTANT CONCEPT YOU CAN EXTRACT THEM EACH IN THEIR OWN QUOTATION MARKS. Here's the question: {question}")).await.unwrap_or("".to_string()); + + let entities: Vec = string_extraction_regex + .captures_iter(&main_entity) + .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string())) + .collect(); + + let number_answers = entities.len(); + + tracing::info!("important word found: {entities:?}"); + + let important_concepts = self.fast_agent.prompt(format!("Can you extract the all keywords and important concepts from the original question. You can also add keywords related that aren't directly in the question. THE ANSWER SHOULD BE A SERIE OF SINGLE WORD OR A SINGLE CONCEPT IN QUOTATION MARKS. OUTPUT THEM EACH IN THEIR OWN QUOTATION MARKS. 
Here's the question: {question}")).await.unwrap_or("".to_string()); + + let concepts: Vec = string_extraction_regex + .captures_iter(&important_concepts) + .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string())) + .collect(); + + tracing::info!("important concepts are: {concepts:?}"); + + let concepts_embeddings: Vec> = concepts + .into_iter() + .map(|concept| { + self.embedding_model + .embed(vec![concept], None) + .expect("Failed to get embedding") + .pop() + .expect("Embedding is empty") + .into_iter() + .map(|v| v as f64) + .collect::>() + }) + .collect(); + + let expanded_paths = join_all( + join_all(entities.into_iter().map(|entity| async { + let semantic_search = self.search(entity, Some(1)).await.unwrap_or_default(); + self.get_ids_from_search(semantic_search, create_relation_filter) + .await + .unwrap_or_default() + })) + .await + .into_iter() + .flatten() + .map(|sart_node_id| async { + self.automatic_explore_node(sart_node_id, concepts_embeddings.clone()) + .await + }), + ) + .await + .to_vec(); + + let answers_combined = join_all(expanded_paths.iter().map(|path| async { + let agent_prompt = format!( + "From the search that was done can you provide a final answer? Here's the original question: {question}\nThe full search of nodes:{}", path.clone() + ); + + self.thinking_agent + .prompt(agent_prompt) + .await.unwrap_or("Agent error".to_string()) + })).await.join("]\n["); + + let final_answer = if number_answers == 1 { + answers_combined + } else { + let final_answer_prompt = format!( + "BASE YOUR ANSWER ONLY ON THE PARTIAL ANSWERS! YOU ANSWER ONLY THE ORIGINAL QUESTION SINCE THE MAIN USER DOESN'T CARE ABOUT PARTIAL ANSWERS. GIVE A FULL COMPLETE AND DETAILLED ANSWER. 
From the given partial answer from different starting points, can you provide a final single answer to the original question: {question}\n Here are the different partial answers:[{answers_combined}]" + ); + self.fast_agent + .prompt(final_answer_prompt) + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + }; + + Ok(Json(final_answer)) + } + + async fn automatic_explore_node( + &self, + start_id: String, + concepts_embeddings: Vec>, + ) -> String { + let mut seen_ids = HashSet::new(); + + let start_name = self + .get_name_of_id(start_id.clone()) + .await + .unwrap_or("No name".to_string()); + + let mut path_root = PathNode { + id: start_id.clone(), + relation_name: "Start".to_string(), + relation_name_embedding: Vec::new(), + entity_name: start_name, + entity_name_embedding: Vec::new(), + nodes: Vec::new(), + is_inbound: false, + page_rank: 0.0, + depth: 0, + }; + + let mut stack: Vec<&mut PathNode> = vec![&mut path_root]; + + while let Some(current) = stack.pop() { + if !seen_ids.insert(current.id.clone()) { + continue; + } + + let neighbor_nodes = self + .automatic_expand_entity( + current.id.clone(), + current.depth, + concepts_embeddings.clone(), + ) + .await + .unwrap_or_default(); + + for neighbor in neighbor_nodes { + current.nodes.push(PathNode { + nodes: Vec::new(), + id: neighbor.id.clone(), + depth: neighbor.depth + 1, + ..neighbor + }); + } + + current + .nodes + .iter_mut() + .for_each(|neighbor| stack.insert(0, neighbor)); + } + + self.create_path(&path_root, 0).await + } + + async fn automatic_expand_entity( + &self, + base_id: String, + depth: usize, + concepts_embeddings: Vec>, + ) -> Result, StatusCode> { + const MAX_SAME_RELATION: usize = 10; + const NAME_IMPORTANCE_FACTOR: f64 = 0.9; + const MAX_DISPLAYED_NEIGHBORS: usize = 35; + const EMBEDDING_DISTANCE_THRESHOLD: f64 = 0.65; + + let mut relation_count = HashMap::new(); + + let mut relations = self + 
.extract_path_nodes_to_neighbors(base_id.clone(), depth, false) + .await + .unwrap_or_default(); + relations.extend( + self.extract_path_nodes_to_neighbors(base_id.clone(), depth, true) + .await + .unwrap_or_default(), + ); + + let mut relations_with_scores: Vec<(f64, PathNode)> = relations + .into_iter() + .map(|path_node| { + // Compute all cosine distances to concept embeddings and take the best (lowest) + let distance_score = concepts_embeddings + .iter() + .map(|concept| { + let name_dist = + self.cosine_distance_unit(concept, &path_node.entity_name_embedding); + let rel_dist = + self.cosine_distance_unit(concept, &path_node.relation_name_embedding); + name_dist.min(rel_dist) * NAME_IMPORTANCE_FACTOR + + name_dist.max(rel_dist) * (1.0 - NAME_IMPORTANCE_FACTOR) + }) + .fold(f64::INFINITY, |a, b| a.min(b)); + + let depth_penalty = 0.02 * depth as f64; + let depth_centered = 3.5 - depth as f64; + let depth_factor = depth_centered.powf(1.0 / 5.0) * 20.0; + let page_rank_clipped = path_node.page_rank.clamp(-0.5, 0.8); + let final_score = + depth_penalty + distance_score * (0.9 - page_rank_clipped / depth_factor); // /depth_factor + + (final_score, path_node) + }) + .collect(); + + relations_with_scores + .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Less)); + + let relations: Vec = relations_with_scores + .into_iter() + .filter(|(score, _)| score < &EMBEDDING_DISTANCE_THRESHOLD) + .filter_map(|(_, path_node)| { + let count = relation_count + .entry(path_node.relation_name.clone()) + .and_modify(|c| *c += 1) + .or_insert(1); + if *count <= MAX_SAME_RELATION { + Some(path_node) + } else { + None + } + }) + .take(MAX_DISPLAYED_NEIGHBORS) + .collect(); + + Ok(relations) + } + + async fn search( + &self, + query: String, + limit: Option, + ) -> Result, StatusCode> { + let embedding = self + .embedding_model + .embed(vec![&query], None) + .expect("Failed to get embedding") + .pop() + .expect("Embedding is empty") + .into_iter() + .map(|v| v as f64) 
+ .collect::>(); + + let limit = limit.unwrap_or(10); + + let semantic_search_triples = triple::search(&self.neo4j, embedding) + .limit(limit) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + Ok(semantic_search_triples) + } + + async fn get_ids_from_search( + &self, + search_triples: Vec, + create_relation_filter: impl Fn(SemanticSearchResult) -> RelationFilter, + ) -> Result, StatusCode> { + let mut seen_ids: HashSet = HashSet::new(); + let mut result_ids: Vec = Vec::new(); + + for semantic_search_triple in search_triples { + let filtered_for_types = relation::find_many::>(&self.neo4j) + .filter(create_relation_filter(semantic_search_triple)) + .send() + .await; + + //We only need to get the first relation since they would share the same entity id + if let Ok(stream) = filtered_for_types { + pin_mut!(stream); + if let Some(edge) = stream.try_next().await.ok().flatten() { + let id = edge.from.id; + if seen_ids.insert(id.clone()) { + result_ids.push(id); + } + } + } + } + Ok(result_ids) + } + + async fn extract_path_nodes_to_neighbors( + &self, + id: String, + depth: usize, + is_inbound: bool, + ) -> Result, StatusCode> { + const MAX_RELATIONS_CONSIDERED: usize = 100; + + let in_filter = relation::RelationFilter::default() + .to_(EntityFilter::default().id(prop_filter::value(id.clone()))); + let out_filter = relation::RelationFilter::default() + .from_(EntityFilter::default().id(prop_filter::value(id.clone()))); + + let relations = relation::find_many::>(&self.neo4j) + .filter(if is_inbound { in_filter } else { out_filter }) + .limit(MAX_RELATIONS_CONSIDERED) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let related_nodes = join_all(relations.into_iter().map(|result| async move { + let page_rank: f64; + let id: &str; + + if is_inbound { + id = &result.from.id; + page_rank = result.from.system_properties.page_rank.unwrap_or(0.0); + } else { + id = &result.to.id; + page_rank = result.to.system_properties.page_rank.unwrap_or(0.0); + }; + let (entity_name_embedding, entity_name) = self + .get_name_and_embedding_of_id(id.to_string()) + .await + .unwrap_or((Vec::new(), "No name".to_string())); + let (relation_name_embedding, relation_name) = self + .get_name_and_embedding_of_id(result.relation_type.clone()) + .await + .unwrap_or((Vec::new(), "No name".to_string())); + + PathNode { + id: id.to_string(), + relation_name, + relation_name_embedding, + entity_name, + entity_name_embedding, + nodes: Vec::new(), + is_inbound, + page_rank, + depth: depth + 1, + } + })) + .await + .to_vec(); + + Ok(related_nodes) + } + + async fn create_path(&self, path_node: &PathNode, depth: usize) -> String { + const MAX_DISPLAY_LENGTH: usize = 8; + + let shorten_string = |word: String, max_length: usize| { + let splitted: Vec<&str> = word.split_whitespace().collect(); + if splitted.len() > max_length { + splitted[..max_length].join(" ") + "..." 
+ } else { + splitted.join(" ") + } + }; + + let mut base = format!( + "{}{}-[{}]-{}{}\n", + " ".repeat(depth), + if path_node.is_inbound { "<" } else { "" }, + shorten_string(path_node.relation_name.clone(), MAX_DISPLAY_LENGTH), + if path_node.is_inbound { "" } else { ">" }, + shorten_string(path_node.entity_name.clone(), MAX_DISPLAY_LENGTH), + ); + + let rest = join_all( + path_node + .nodes + .iter() + .map(|node| async { self.create_path(node, depth + 1).await }), + ) + .await + .join(""); + + base.push_str(&rest); + base + } + + async fn get_name_of_id(&self, id: String) -> Result { + let entity = entity::find_one::>(&self.neo4j, &id) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + entity + .attributes + .name + .ok_or(StatusCode::INTERNAL_SERVER_ERROR) + } + + async fn get_name_and_embedding_of_id( + &self, + id: String, + ) -> Result<(Vec, String), StatusCode> { + let triple_result = triple::find_many(&self.neo4j) + .entity_id(prop_filter::value(&id)) + .attribute_id(prop_filter::value(system_ids::NAME_ATTRIBUTE)) + .limit(1) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + if let Some(name_entity) = triple_result.first() + && let Some(embedding) = &name_entity.embedding + { + return Ok((embedding.clone(), name_entity.value.value.clone())); + } + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + + #[inline(always)] + fn cosine_distance_unit(&self, a: &[f64], b: &[f64]) -> f64 { + let mut dot = 0.0; + for (&x, &y) in a.iter().zip(b.iter()) { + dot += x * y; + } + 1.0 - dot + } +} + +#[derive(Clone, Debug)] +struct PathNode { + id: String, + relation_name: String, + relation_name_embedding: Vec, + entity_name: String, + entity_name_embedding: Vec, + nodes: Vec, + is_inbound: bool, + page_rank: f64, + depth: usize, +} diff --git a/ai-search/src/full_ai_search.rs b/ai-search/src/full_ai_search.rs new file mode 100644 index 0000000..2d96228 --- /dev/null +++ b/ai-search/src/full_ai_search.rs @@ -0,0 +1,665 @@ +use anyhow::Error; +use axum::{extract::Json, http::StatusCode}; +use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; +use futures::{TryStreamExt, future::join_all, pin_mut}; +use grc20_core::{ + entity::{self, Entity, EntityFilter, EntityNode}, + mapping::{ + Query, QueryStream, RelationEdge, prop_filter, + triple::{self, SemanticSearchResult}, + }, + neo4rs, + relation::{self, RelationFilter}, +}; +use grc20_sdk::models::BaseEntity; +use rand::{Rng, distributions::Alphanumeric}; +use regex::Regex; +use rig::{agent::Agent, completion::Prompt, providers::gemini::completion::CompletionModel}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; +use tokio::sync::Mutex; + +const EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::AllMiniLML6V2; + +pub struct FullAISearchAgent { + neo4j: neo4rs::Graph, + pub embedding_model: Arc, + pub fast_agent: Agent, + pub thinking_agent: Agent, +} + +#[derive(Clone, Debug)] +struct PathNode { + id: String, + relation_name: String, + entity_name: String, + nodes: 
Vec, + is_explored: bool, + is_inbound: bool, + is_hidden: bool, + depth: usize, +} + +type PathRef = Arc>; + +impl FullAISearchAgent { + pub fn new(neo4j: neo4rs::Graph, gemini_api_key: &str) -> Self { + let traversal_system_prompt = include_str!("../ressources/traversal_prompt.md"); + let preamble = "You are a knowledge graph helper to answer natural language questions."; + let gemini_client = rig::providers::gemini::Client::new(gemini_api_key); + + Self { + neo4j, + embedding_model: Arc::new( + TextEmbedding::try_new( + InitOptions::new(EMBEDDING_MODEL).with_show_download_progress(true), + ) + .expect("Failed to initialize embedding model"), + ), + fast_agent: gemini_client + .agent("gemini-2.5-flash-lite") + .temperature(0.4) + .preamble(traversal_system_prompt) + .build(), + thinking_agent: gemini_client + .agent("gemini-2.5-flash") + .temperature(0.4) + .preamble(preamble) + .build(), + } + } + + pub async fn natural_language_question( + &self, + question: String, + ) -> Result, StatusCode> { + let string_extraction_regex = Regex::new(r#""([^"]+)""#).unwrap(); + + let create_relation_filter = |search_result: SemanticSearchResult| { + RelationFilter::default() + .from_(EntityFilter::default().id(prop_filter::value(search_result.triple.entity))) + }; + + let main_entity = self.thinking_agent.prompt(format!("Can you extract the main 1-3 Entities from which you can do a search from the original question. THE ANSWER SHOULD BE A SINGLE WORD OR A SINGLE CONCEPT IN QUOTATION MARKS. IF THERE IS MORE THAN ONE IMPORTANT CONCEPT YOU CAN EXTRACT THEM EACH IN THEIR OWN QUOTATION MARKS. 
Here's the question: {question}")).await.unwrap_or("".to_string()); + + let entities: Vec = string_extraction_regex + .captures_iter(&main_entity) + .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string())) + .collect(); + + let number_answers = entities.len(); + + tracing::info!("important word found: {entities:?}"); + + let important_concepts = self.thinking_agent.prompt(format!("Can you extract the all keywords and important concepts from the original question. You can also add keywords related that aren't directly in the question. THE ANSWER SHOULD BE A SERIE OF SINGLE WORD OR A SINGLE CONCEPT IN QUOTATION MARKS. OUTPUT THEM EACH IN THEIR OWN QUOTATION MARKS. Here's the question: {question}")).await.unwrap_or("".to_string()); + + let concepts: Vec = string_extraction_regex + .captures_iter(&important_concepts) + .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string())) + .collect(); + + tracing::info!("important concepts are: {concepts:?}"); + + let concepts_embeddings: Vec> = self + .embedding_model + .embed(concepts, None) + .expect("Failed to get embedding") + .into_iter() + .map(|vec| vec.into_iter().map(|v| v as f64).collect()) + .collect::>(); + + let expanded_paths = join_all( + join_all(entities.into_iter().map(|entity| async { + let semantic_search = self.search(entity, Some(1)).await.unwrap_or_default(); + self.get_ids_from_search(semantic_search, create_relation_filter) + .await + .unwrap_or_default() + })) + .await + .into_iter() + .flatten() + .map(|start_node_id| async { + self.explore_node(question.clone(), concepts_embeddings.clone(), start_node_id) + .await + }), + ) + .await + .to_vec(); + + let answers_combined = join_all(expanded_paths.iter().map(|path| async { + let agent_prompt = format!( + "From the search that was done can you provide a final answer? 
Here's the original question: {question}\nThe full search of nodes:{}", path.clone() + ); + + self.thinking_agent + .prompt(agent_prompt) + .await + .unwrap_or("Agent error".to_string()) + })).await + .join("]\n["); + + let final_answer = if number_answers == 1 { + Ok(answers_combined) + } else { + let final_answer_prompt = format!( + "BASE YOUR ANSWER ONLY ON THE PARTIAL ANSWERS! YOU ANSWER ONLY THE ORIGINAL QUESTION SINCE THE MAIN USER DOESN'T CARE ABOUT PARTIAL ANSWERS. GIVE A FULL COMPLETE AND DETAILLED ANSWER. From the given partial answer from different starting points, can you provide a final single answer to the original question: {question}\n Here are the different partial answers:[{answers_combined}]" + ); + self.thinking_agent.prompt(final_answer_prompt).await + }; + + match final_answer { + Ok(answer) => Ok(Json(answer)), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + } + } + + async fn create_path(&self, path_node: &PathNode, depth: usize) -> String { + if path_node.is_hidden || depth > 10 { + return "".to_string(); + } + const MAX_DISPLAY_LENGTH: usize = 8; + + let shorten_string = |word: String, max_length: usize| { + let splitted: Vec<&str> = word.split_whitespace().collect(); + if splitted.len() > max_length && !path_node.is_explored { + splitted[..max_length].join(" ") + "..." + } else { + splitted.join(" ") + } + }; + + let mut base = format!( + "{}{}-[{}]-{}{}{}({})\n", + " ".repeat(depth), + if path_node.is_inbound { "<" } else { "" }, + shorten_string(path_node.relation_name.clone(), MAX_DISPLAY_LENGTH), + if path_node.is_inbound { "" } else { ">" }, + if path_node.is_explored { "+" } else { "?" 
}, + shorten_string(path_node.entity_name.clone(), MAX_DISPLAY_LENGTH), + path_node.id.chars().take(6).collect::(), + ); + + let rest = join_all(path_node.nodes.iter().map(|node| async { + self.create_path(&node.lock().await.clone(), depth + 1) + .await + })) + .await + .join(""); + + base.push_str(&rest); + base + } + + async fn explore_node( + &self, + question: String, + concepts_embeddings: Vec>, + start_id: String, + ) -> String { + let explore_regex = Regex::new(r"(.*?)").unwrap(); + let answer_regex = Regex::new(r"(.*?)").unwrap(); + let end_regex = Regex::new(r"").unwrap(); + + let mut full_id_resolver: HashMap = HashMap::new(); + let mut expansions: HashMap> = HashMap::new(); + let mut seen_ids = HashSet::new(); + let mut seen_twice = HashSet::new(); + let mut explores = Vec::new(); + let mut partial_answers = Vec::new(); + let mut potential_explore = Vec::new(); + + let start_name = self + .get_name_of_id(start_id.clone()) + .await + .unwrap_or("No name".to_string()); + + let path_root: PathRef = Arc::new(Mutex::new(PathNode { + id: start_id.clone(), + relation_name: "Start".to_string(), + entity_name: start_name, + nodes: Vec::new(), + is_explored: false, + is_inbound: false, + is_hidden: false, + depth: 0, + })); + + potential_explore.push(path_root.clone()); + + let mut stack: Vec = vec![path_root.clone()]; + full_id_resolver.insert(start_id.chars().take(6).collect(), start_id); + + while let Some(current) = stack.pop() { + let current_node = current.lock().await; + let mut warning = false; + + if !seen_ids.insert(current_node.id.clone()) { + warning = true; + if !seen_twice.insert(current_node.id.clone()) { + continue; + } + } + + if partial_answers.len() >= 10 { + break; + } + + let data = format!( + "name: {}, description:{}", + current_node.entity_name, + self.get_description_of_id(current_node.id.clone()) + .await + .unwrap_or("No description".to_string()) + ); + + let (neighbor_nodes, hidden_options) = self + 
.expand_entity(current_node.id.clone(), concepts_embeddings.clone()) + .await + .unwrap_or_default(); + + drop(current_node); + let mut current_node_mut = current.lock().await; + current_node_mut.is_explored = true; + + for neighbor in neighbor_nodes.clone() { + let new_node = Arc::new(Mutex::new(PathNode { + nodes: Vec::new(), + is_explored: seen_ids.contains(&neighbor.id.clone()), + id: neighbor.id.clone(), + depth: neighbor.depth + 1, + ..neighbor + })); + full_id_resolver.insert(neighbor.id.chars().take(6).collect(), neighbor.id); + + if !warning && !neighbor.is_explored { + current_node_mut.nodes.push(new_node.clone()); + potential_explore.push(new_node); + } + } + + for (relation_name, hidden_path_nodes) in hidden_options { + let expand_id: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(6) + .map(char::from) + .collect(); + + let expansion_node = Arc::new(Mutex::new(PathNode { + id: expand_id.clone(), + relation_name: relation_name.clone(), + entity_name: format!( + "Expand {} more {relation_name}...", + hidden_path_nodes.len() + ), + nodes: Vec::new(), + is_explored: false, + is_inbound: false, + is_hidden: false, + depth: 0, + })); + if !warning { + current_node_mut.nodes.push(expansion_node.clone()); + } + + for neighbor in hidden_path_nodes { + let new_node = Arc::new(Mutex::new(PathNode { + nodes: Vec::new(), + is_explored: seen_ids.contains(&neighbor.id.clone()), + id: neighbor.id.clone(), + is_hidden: true, + ..neighbor + })); + full_id_resolver.insert(neighbor.id.chars().take(6).collect(), neighbor.id); + expansions + .entry(expand_id.clone()) + .or_default() + .push(new_node.clone()); + + if !warning && !neighbor.is_explored { + current_node_mut.nodes.push(new_node.clone()); + potential_explore.push(new_node); + } + } + } + + drop(current_node_mut); + + let named_path: String = { + let snapshot_of_path = path_root.lock().await; + self.create_path(&snapshot_of_path, 0).await + }; + + let warning_message = if warning { + "YOU HAVE 
ALREADY SEEN THIS NODE! YOU WON'T BE ABLE TO SEE IT AGAIN! SEEING IT AGAIN WON'T GIVE YOU MORE INFORMATION." + } else { + "" + }; + + let agent_prompt = format!( + "{warning_message}\nThe current explored node is:\n{data}\nThe original question is:\n{question}\nThe full exploration is:{named_path}" + ); + + let prompt_answer = self + .fast_agent + .prompt(agent_prompt) + .await + .unwrap_or("Agent error".to_string()); + + tracing::info!("Agent answer: {prompt_answer}"); + + for explore in explore_regex.captures_iter(&prompt_answer) { + if let Some(full_id) = full_id_resolver.get(&explore[1]) { + explores.push(full_id.clone()); + } + if let Some(hidden_nodes) = expansions.get(&explore[1]) { + let mut first = true; + for node in hidden_nodes { + let mut current_node_mut = node.lock().await; + current_node_mut.is_hidden = false; + if first { + explores.push(current_node_mut.id.clone()); + first = false; + } + drop(current_node_mut); + } + } + } + + if let Some(answer) = answer_regex.captures(&prompt_answer) { + partial_answers.push(answer[1].to_string()); + stack.push(current.clone()); + } + + if end_regex.captures(&prompt_answer).is_some() { + break; + } + + for potential_exploration_node in &potential_explore { + let node = potential_exploration_node.lock().await; + + let insert_node = explores.contains(&node.id); + explores.retain(|item| item != &node.id); + drop(node); + if insert_node { + // BFS + stack.insert(0, potential_exploration_node.clone()); + } + } + } + + let named_path: String = { + let snapshot_of_path = path_root.lock().await; + self.create_path(&snapshot_of_path, 0).await + }; + + let agent_prompt = format!( + "From the search that was done can you provide a final answer? 
Here's the original question: {question}\nThe partial_answers are:[{}]\nThe full path of nodes:{named_path}", + partial_answers.join("]\n[") + ); + + self.thinking_agent + .prompt(agent_prompt) + .await + .unwrap_or("Agent error".to_string()) + } + + async fn expand_entity( + &self, + base_id: String, + concepts_embeddings: Vec>, + ) -> Result<(Vec, HashMap>), Error> { + const IMPORTANCE_FACTOR: f64 = 0.9; + const MAX_SAME_RELATION: usize = 4; + const MAX_DIPLAYED_NEIGHBORS: usize = 15; + const EMBEDDING_DISTANCE_THRESHOLD: f64 = 1.5; + + let mut relation_count: HashMap = HashMap::new(); + let mut hidden_options: HashMap> = HashMap::new(); + + let mut relations = self + .extract_path_nodes_to_neighbors(base_id.clone(), 0, false) + .await + .unwrap_or_default(); + + relations.extend( + self.extract_path_nodes_to_neighbors(base_id.clone(), 0, true) + .await + .unwrap_or_default(), + ); + + let mut relations_with_scores: Vec<(f64, PathNode)> = relations + .into_iter() + .map(|path_node| { + let name_embedding = self.create_embedding(vec![&path_node.entity_name]); + let relation_embedding = self.create_embedding(vec![&path_node.relation_name]); + + // Compute all cosine distances to concept embeddings and take the best (lowest) + let min_name_distance = concepts_embeddings + .iter() + .map(|concept| self.cosine_distance(concept, &name_embedding)) + .fold(f64::INFINITY, |a, b| a.min(b)); + + // Compute all cosine distances to concept embeddings and take the best (lowest) + let min_rel_distance = concepts_embeddings + .iter() + .map(|concept| self.cosine_distance(concept, &relation_embedding)) + .fold(f64::INFINITY, |a, b| a.min(b)); + + // Many times, the same relation name is used. 
It now differentiate on the best entity name + let distance_score = if min_name_distance < min_rel_distance { + min_name_distance * IMPORTANCE_FACTOR + + min_rel_distance * (1.0 - IMPORTANCE_FACTOR) + } else { + min_name_distance * (1.0 - IMPORTANCE_FACTOR) + + min_rel_distance * IMPORTANCE_FACTOR + }; + + (distance_score, path_node) + }) + .collect(); + + relations_with_scores + .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Less)); + + let relations: Vec = relations_with_scores + .into_iter() + .filter(|(score, _)| score < &EMBEDDING_DISTANCE_THRESHOLD) + .filter_map(|(_, path_node)| { + let count = relation_count + .entry(path_node.relation_name.clone()) + .and_modify(|c| *c += 1) + .or_insert(1); + if *count <= MAX_SAME_RELATION { + Some(path_node) + } else { + hidden_options + .entry(path_node.relation_name.clone()) + .or_default() + .push(path_node); + None + } + }) + .take(MAX_DIPLAYED_NEIGHBORS) + .collect(); + + Ok((relations, hidden_options)) + } + + async fn extract_path_nodes_to_neighbors( + &self, + id: String, + depth: usize, + is_inbound: bool, + ) -> Result, StatusCode> { + const MAX_RELATIONS_CONSIDERED: usize = 100; + + let in_filter = relation::RelationFilter::default() + .to_(EntityFilter::default().id(prop_filter::value(id.clone()))); + let out_filter = relation::RelationFilter::default() + .from_(EntityFilter::default().id(prop_filter::value(id.clone()))); + + let relations = relation::find_many::>(&self.neo4j) + .filter(if is_inbound { in_filter } else { out_filter }) + .limit(MAX_RELATIONS_CONSIDERED) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let related_nodes = join_all(relations.into_iter().map(|result| async move { + let id = if is_inbound { + result.from.id + } else { + result.to.id + }; + + let entity_name = self + .get_name_of_id(id.to_string()) + .await + .unwrap_or("No name".to_string()); + + let relation_name = self + .get_name_of_id(result.relation_type.clone()) + .await + .unwrap_or("No name".to_string()); + + PathNode { + id: id.to_string(), + relation_name, + entity_name, + nodes: Vec::new(), + is_inbound, + is_explored: false, + is_hidden: false, + depth: depth + 1, + } + })) + .await + .to_vec(); + + Ok(related_nodes) + } + + async fn get_name_of_id(&self, id: String) -> Result { + let entity = entity::find_one::>(&self.neo4j, &id) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + entity + .attributes + .name + .ok_or(StatusCode::INTERNAL_SERVER_ERROR) + } + + async fn get_description_of_id(&self, id: String) -> Result { + let entity = entity::find_one::>(&self.neo4j, &id) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(entity + .attributes + .description + .unwrap_or("No description".to_string())) + } + + async fn search( + &self, + query: String, + limit: Option, + ) -> Result, StatusCode> { + let embedding = self + .embedding_model + .embed(vec![&query], None) + .expect("Failed to get embedding") + .pop() + .expect("Embedding is empty") + .into_iter() + .map(|v| v as f64) + .collect::>(); + + let limit = limit.unwrap_or(10); + + let semantic_search_triples = triple::search(&self.neo4j, embedding) + .limit(limit) + .send() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
+ .try_collect::>() + .await + .map_err(|e| { + tracing::error!("Error: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + Ok(semantic_search_triples) + } + + async fn get_ids_from_search( + &self, + search_triples: Vec, + create_relation_filter: impl Fn(SemanticSearchResult) -> RelationFilter, + ) -> Result, StatusCode> { + let mut seen_ids: HashSet = HashSet::new(); + let mut result_ids: Vec = Vec::new(); + + for semantic_search_triple in search_triples { + let filtered_for_types = relation::find_many::>(&self.neo4j) + .filter(create_relation_filter(semantic_search_triple)) + .send() + .await; + + //We only need to get the first relation since they would share the same entity id + if let Ok(stream) = filtered_for_types { + pin_mut!(stream); + if let Some(edge) = stream.try_next().await.ok().flatten() { + let id = edge.from.id; + if seen_ids.insert(id.clone()) { + result_ids.push(id); + } + } + } + } + Ok(result_ids) + } + + #[inline(always)] + fn create_embedding(&self, name: Vec<&str>) -> Vec { + self.embedding_model + .embed(name, None) + .expect("Failed to get embedding") + .pop() + .expect("No embedding found") + .into_iter() + .map(|v| v as f64) + .collect::>() + } + + fn cosine_distance(&self, a: &[f64], b: &[f64]) -> f64 { + let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let norm_a: f64 = a.iter().map(|x| x * x).sum::().sqrt(); + let norm_b: f64 = b.iter().map(|x| x * x).sum::().sqrt(); + 1.0 - (dot_product / (norm_a * norm_b)) + } +} diff --git a/ai-search/src/main.rs b/ai-search/src/main.rs new file mode 100644 index 0000000..07dd7f2 --- /dev/null +++ b/ai-search/src/main.rs @@ -0,0 +1,97 @@ +use axum::{Router, extract::Json, http::StatusCode, routing::post}; +use clap::{Args, Parser}; +use grc20_core::neo4rs; +use std::sync::Arc; +use tracing_subscriber::{ + layer::SubscriberExt, + util::SubscriberInitExt, + {self}, +}; + +use crate::{automatic_search::AutomaticSearchAgent, full_ai_search::FullAISearchAgent}; + +mod 
automatic_search; +mod full_ai_search; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let args = AppArgs::parse(); + + let neo4j = neo4rs::Graph::new( + &args.neo4j_args.neo4j_uri, + &args.neo4j_args.neo4j_user, + &args.neo4j_args.neo4j_pass, + ) + .await?; + + let full_ai_search_agent = + Arc::new(FullAISearchAgent::new(neo4j.clone(), &args.gemini_api_key)); + let automatic_search_agent = Arc::new(AutomaticSearchAgent::new( + neo4j.clone(), + &args.gemini_api_key, + )); + + let app = Router::new() + .route("/question_ai", post(handler_full_ai_search)) + .route("/question", post(handler_automatic_search)) + .with_state((full_ai_search_agent, automatic_search_agent)); + + // run our app with hyper, listening globally on port 3000 + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + axum::serve(listener, app).await.unwrap(); + Ok(()) +} + +async fn handler_full_ai_search( + axum::extract::State((search_agent, _)): axum::extract::State<( + Arc, + Arc, + )>, + Json(question): Json, +) -> Result, StatusCode> { + tracing::info!("The question asked to the knowledge graph is: {question}"); + search_agent.natural_language_question(question).await +} + +async fn handler_automatic_search( + axum::extract::State((_, search_agent)): axum::extract::State<( + Arc, + Arc, + )>, + Json(question): Json, +) -> Result, StatusCode> { + tracing::info!("The question asked to the knowledge graph is: {question}"); + search_agent.natural_language_question(question).await +} + +#[derive(Debug, Parser)] +#[command(name = "stdout", version, about, arg_required_else_help = true)] +struct AppArgs { + #[clap(flatten)] + neo4j_args: Neo4jArgs, + #[arg(long)] + gemini_api_key: String, +} + +#[derive(Debug, Args)] +struct Neo4jArgs { + /// Neo4j database 
host + #[arg(long)] + neo4j_uri: String, + + /// Neo4j database user name + #[arg(long)] + neo4j_user: String, + + /// Neo4j database user password + #[arg(long)] + neo4j_pass: String, +} diff --git a/grc20-core/src/mapping/entity/mod.rs b/grc20-core/src/mapping/entity/mod.rs index 3a103f9..3286ff3 100644 --- a/grc20-core/src/mapping/entity/mod.rs +++ b/grc20-core/src/mapping/entity/mod.rs @@ -141,10 +141,7 @@ pub fn prefiltered_search( PrefilteredSemanticSearchQuery::new(neo4j, vector) } -pub fn search_from_restictions( - neo4j: &neo4rs::Graph, - vector: Vec, -) -> SearchWithTraversals { +pub fn traversal_search(neo4j: &neo4rs::Graph, vector: Vec) -> SearchWithTraversals { SearchWithTraversals::new(neo4j, vector) } diff --git a/grc20-core/src/mapping/entity/models.rs b/grc20-core/src/mapping/entity/models.rs index f1e1dfc..51dc307 100644 --- a/grc20-core/src/mapping/entity/models.rs +++ b/grc20-core/src/mapping/entity/models.rs @@ -233,7 +233,7 @@ impl Entity { } } -#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq)] pub struct SystemProperties { #[serde(rename = "82nP7aFmHJLbaPFszj2nbx")] // CREATED_AT_TIMESTAMP pub created_at: DateTime, @@ -243,6 +243,8 @@ pub struct SystemProperties { pub updated_at: DateTime, #[serde(rename = "7pXCVQDV9C7ozrXkpVg8RJ")] // UPDATED_AT_BLOCK pub updated_at_block: String, + #[serde(rename = "pagerank")] + pub page_rank: Option, } impl From for SystemProperties { @@ -252,6 +254,7 @@ impl From for SystemProperties { created_at_block: block.block_number.to_string(), updated_at: block.timestamp, updated_at_block: block.block_number.to_string(), + page_rank: None, } } } @@ -263,6 +266,7 @@ impl Default for SystemProperties { created_at_block: "0".to_string(), updated_at: Default::default(), updated_at_block: "0".to_string(), + page_rank: None, } } } diff --git a/grc20-core/src/mapping/entity/semantic_search.rs 
b/grc20-core/src/mapping/entity/semantic_search.rs index a0155a6..0aa1de1 100644 --- a/grc20-core/src/mapping/entity/semantic_search.rs +++ b/grc20-core/src/mapping/entity/semantic_search.rs @@ -123,8 +123,9 @@ impl QueryStream> for SemanticSearchQuery:\n{}", - query.compile() + "entity_node::FindManyQuery:::\n{}\nparams:{:?}", + query.compile(), + query.params() ); }; diff --git a/grc20-core/src/mapping/entity/utils.rs b/grc20-core/src/mapping/entity/utils.rs index faadd08..2aeee91 100644 --- a/grc20-core/src/mapping/entity/utils.rs +++ b/grc20-core/src/mapping/entity/utils.rs @@ -231,6 +231,9 @@ impl TraverseRelation { RelationDirection::To => { format!("({node_var_dest}) -[{rel_edge_var}:RELATION]-> ({node_var_curr})") } + RelationDirection::Both => { + format!("({node_var_curr}) -[{rel_edge_var}:RELATION]- ({node_var_dest})") + } }) // rename to change direction of relation .rename(Rename::new(NamePair::new( diff --git a/grc20-core/src/mapping/query_utils/relation_direction.rs b/grc20-core/src/mapping/query_utils/relation_direction.rs index d1cc35b..7a3d086 100644 --- a/grc20-core/src/mapping/query_utils/relation_direction.rs +++ b/grc20-core/src/mapping/query_utils/relation_direction.rs @@ -1,6 +1,7 @@ #[derive(Clone, Debug, Default)] pub enum RelationDirection { From, - #[default] To, + #[default] + Both, } diff --git a/grc20-core/src/mapping/relation/find_many.rs b/grc20-core/src/mapping/relation/find_many.rs index ea8f049..c3e94b0 100644 --- a/grc20-core/src/mapping/relation/find_many.rs +++ b/grc20-core/src/mapping/relation/find_many.rs @@ -123,8 +123,8 @@ impl FindManyQuery { .r#where(self.version.subquery("r")), ) .subquery(self.filter.subquery("r", "from", "to")) - .subquery("ORDER BY r.index") .limit(self.limit) + .subquery("ORDER BY r.index") .skip_opt(self.skip) } diff --git a/grc20-core/src/mapping/triple.rs b/grc20-core/src/mapping/triple.rs index dd3485e..0cc95d2 100644 --- a/grc20-core/src/mapping/triple.rs +++ b/grc20-core/src/mapping/triple.rs 
@@ -526,6 +526,7 @@ pub struct FindManyQuery { value_type: Option>, entity_id: Option>, space_id: Option>, + limit: usize, space_version: VersionFilter, } @@ -538,6 +539,7 @@ impl FindManyQuery { value_type: None, entity_id: None, space_id: None, + limit: 100, space_version: VersionFilter::default(), } } @@ -567,6 +569,11 @@ impl FindManyQuery { self } + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + pub fn space_version(mut self, space_version: impl Into) -> Self { self.space_version.version_mut(space_version.into()); self @@ -595,6 +602,7 @@ impl FindManyQuery { ) .r#where(self.space_version.subquery("r")), ) + .limit(self.limit) .subquery("RETURN n{.*, entity: e.id}") } } diff --git a/mcp-server/resources/get_entity_info_description.md b/mcp-server/resources/get_entity_info_description.md index fc17802..6786ba3 100644 --- a/mcp-server/resources/get_entity_info_description.md +++ b/mcp-server/resources/get_entity_info_description.md @@ -23,39 +23,21 @@ ToolResult> "name": "SF Mayor Lurie launching police task force to counter crime in core downtown areas", "relation_id": "8ESicJHiNJ28VGL5u34A5q", "relation_type": "Related spaces" - }, - { - "id": "6wAoNdGVbweKi2JRPZP4bX", - "name": "San Francisco Independent Film Festival", - "relation_id": "TH5Tu5Y5nacvREvAQRvcR2", - "relation_type": "Related spaces" - }, - { - "id": "8VCHYDURDStwuTCUBjWLQa", - "name": "Product Engineer at Geo", - "relation_id": "KPTqdNpCusxfM37KbKPX8w", - "relation_type": "Related spaces" - }, ... + } ], "outbound_relations": [ - { - "id": "7gzF671tq5JTZ13naG4tnr", - "name": "Space", - "relation_id": "WUZCXE1UGRtxdNQpGug8Tf", - "relation_type": "Types" - }, - { - "id": "CUoEazCD7EmzXPTFFY8gGY", - "name": "No name", - "relation_id": "5WeSkkE1XXvGJGmXj9VUQ8", - "relation_type": "Cover" - }, { "id": "D6Wy4bdtdoUrG3PDZceHr", "name": "City", "relation_id": "ARMj8fjJtdCwbtZa1f3jwe", "relation_type": "Types" - }, ... 
+ }, + { + "id": "AhidiWYnQ8fAbHqfzdU74k", + "name": "Upcoming events", + "relation_id": "V1ikGW9riu7dAP8rMgZq3u", + "relation_type": "Blocks" + } ] } ``` diff --git a/mcp-server/resources/instructions.md b/mcp-server/resources/instructions.md index 9aeec45..e70d5be 100644 --- a/mcp-server/resources/instructions.md +++ b/mcp-server/resources/instructions.md @@ -5,181 +5,42 @@ You should use it for every request to get the informations for your answers sin The tools defined in the MCP server are made to be used in combination with each other. All except the most trivial requests will require the use of multiple tools. Here is an example: -User> Can you give me information about San Francisco? +User> Can you give me information about other restaurants near Saison? -ToolCall> search_entity({"query": "San Francisco"}) -ToolResult> -``` -{ - "entities": [ - { - "description": "A vibrant city known for its iconic Golden Gate Bridge, steep rolling hills, historic cable cars, and a rich cultural tapestry including diverse neighborhoods like the Castro and the Mission District.", - "id": "3qayfdjYyPv1dAYf8gPL5r", - "name": "San Francisco" - }, - { - "description": null, - "id": "W5ZEpuy3Tij1XSXtJLruQ5", - "name": "SF Bay Area" - }, - { - "description": null, - "id": "RHoJT3hNVaw7m5fLLtZ8WQ", - "name": "California" - }, - { - "description": null, - "id": "Sh1qtjr4i92ZD6YGPeu5a2", - "name": "Abundant housing in San Francisco" - }, - { - "description": null, - "id": "UqLf9fTVKHkDs3LzP9zHpH", - "name": "Public safety in San Francisco" - }, - { - "description": null, - "id": "BeyiZ6oLqLMaSXiG41Yxtf", - "name": "City" - }, - { - "description": null, - "id": "D6Wy4bdtdoUrG3PDZceHr", - "name": "City" - }, - { - "description": null, - "id": "JWVrgUXmjS75PqNX2hry5q", - "name": "Clean streets in San Francisco" - }, - { - "description": null, - "id": "DcA2c7ooFTgEdtaRcaj7Z1", - "name": "Revitalizing downtown San Francisco" - }, - { - "description": null, - "id": 
"KWBLj9czHBBmYUT98rnxVM", - "name": "Location" - } - ] -} -``` -Let's get more info about San Francisco (id: 3qayfdjYyPv1dAYf8gPL5r) +Let's get precise information about Atelier Crenn. +Atelier Crenn id: "WFYTR1pxsZNjk8p6Z2CCg4" -ToolCall> get_entity_info("3qayfdjYyPv1dAYf8gPL5r") +ToolCall> get_entity_info("WFYTR1pxsZNjk8p6Z2CCg4") ToolResult> ``` { "all_attributes": [ - { - "attribute_name": "Description", - "attribute_value": "A vibrant city known for its iconic Golden Gate Bridge, steep rolling hills, historic cable cars, and a rich cultural tapestry including diverse neighborhoods like the Castro and the Mission District." - }, { "attribute_name": "Name", - "attribute_value": "San Francisco" + "attribute_value": "Atelier Crenn" } ], - "id": "3qayfdjYyPv1dAYf8gPL5r", + "id": "WFYTR1pxsZNjk8p6Z2CCg4", "inbound_relations": [ { - "id": "NAMA1uDMzBQTvPYV9N92BV", - "name": "SF Mayor Lurie launching police task force to counter crime in core downtown areas", - "relation_id": "8ESicJHiNJ28VGL5u34A5q", - "relation_type": "Related spaces" - }, - { - "id": "6wAoNdGVbweKi2JRPZP4bX", - "name": "San Francisco Independent Film Festival", - "relation_id": "TH5Tu5Y5nacvREvAQRvcR2", - "relation_type": "Related spaces" - }, - { - "id": "8VCHYDURDStwuTCUBjWLQa", - "name": "Product Engineer at Geo", - "relation_id": "KPTqdNpCusxfM37KbKPX8w", - "relation_type": "Related spaces" - }, - { - "id": "NcQ3h9jeJSavVd8iFsUxvD", - "name": "Senior Civil Engineer @ Golden Gate Bridge, Highway & Transportation District", - "relation_id": "AqpNtJ3XxaY4fqRCyoXbdt", - "relation_type": "Cities" - }, - { - "id": "4ojV4dS1pV2tRnzXTpcMKJ", - "name": "Senior Plan Check Engineer (FT - Hybrid) @ CSG Consultants, Inc.", - "relation_id": "3AX4j43nywT5eBRV3s6AXi", - "relation_type": "Cities" - }, - { - "id": "QoakYWCuv85FVuYdSmonxr", - "name": "Senior Civil Engineer - Land Development (FT - Hybrid) @ CSG Consultants, Inc.", - "relation_id": "8GEF1i3LK4Z56THjE8dVku", - "relation_type": "Cities" - }, - { - 
"id": "JuV7jLoypebzLhkma6oZoU", - "name": "Lead Django Backend Engineer @ Textme Inc", - "relation_id": "46aBsQyBq15DimJ2i1DX4a", - "relation_type": "Cities" - }, - { - "id": "RTmcYhLVmmfgUn9L3D1J3y", - "name": "Chief Engineer @ Wyndham Hotels & Resorts", - "relation_id": "8uYxjzkkdjskDQAeTQomvc", - "relation_type": "Cities" + "id": "T6iKbwZ17iv4dRdR9Qw7qV", + "name": "Trending restaurants", + "relation_id": "Mwrn46KavwfWgNrFaWcB9j", + "relation_type": "Collection item" } ], "outbound_relations": [ { - "id": "CUoEazCD7EmzXPTFFY8gGY", + "id": "AxW1SQEvzvuKkPV6T19VDL", "name": "No name", - "relation_id": "5WeSkkE1XXvGJGmXj9VUQ8", + "relation_id": "7YHk6qYkNDaAtNb8GwmysF", "relation_type": "Cover" }, { - "id": "7gzF671tq5JTZ13naG4tnr", - "name": "Space", - "relation_id": "WUZCXE1UGRtxdNQpGug8Tf", + "id": "A9QizqoXSqjfPUBjLoPJa2", + "name": "Restaurant", + "relation_id": "Jfmby78N4BCseZinBmdVov", "relation_type": "Types" - }, - { - "id": "D6Wy4bdtdoUrG3PDZceHr", - "name": "City", - "relation_id": "ARMj8fjJtdCwbtZa1f3jwe", - "relation_type": "Types" - }, - { - "id": "AhidiWYnQ8fAbHqfzdU74k", - "name": "Upcoming events", - "relation_id": "V1ikGW9riu7dAP8rMgZq3u", - "relation_type": "Blocks" - }, - { - "id": "T6iKbwZ17iv4dRdR9Qw7qV", - "name": "Trending restaurants", - "relation_id": "CvGXCmGXE7ofsgZeWad28p", - "relation_type": "Blocks" - }, - { - "id": "X18WRE36mjwQ7gu3LKaLJS", - "name": "Neighborhoods", - "relation_id": "Uxpsee9LoTgJqMFfAQyJP6", - "relation_type": "Blocks" - }, - { - "id": "HeC2pygci2tnvjTt5aEnBV", - "name": "Top goals", - "relation_id": "5WMTAzCnZH9Bsevou9GQ3K", - "relation_type": "Blocks" - }, - { - "id": "5YtYFsnWq1jupvh5AjM2ni", - "name": "Culture", - "relation_id": "5TmxfepRr1THMRkGWenj5G", - "relation_type": "Tabs" } ] } diff --git a/mcp-server/resources/search_entity_description.md b/mcp-server/resources/search_entity_description.md index ae3a46a..3276ecf 100644 --- a/mcp-server/resources/search_entity_description.md +++ 
b/mcp-server/resources/search_entity_description.md @@ -1,81 +1,38 @@ This request allows you to get Entities from a name/description search and traversal from that query by using relation name. -Example Query: Find employees that works at The Graph. - -ToolCall> -``` -search_entity( - { - "query": "The Graph", - "traversal_filter": { - "relation_type": "Works at", - "direction": "From" - } - } -) -``` - -ToolResult> -``` -{ - "entities": [ - { - "description": "Founder & CEO of Geo. Cofounder of The Graph, Edge & Node, House of Web3. Building a vibrant decentralized future.", - "id": "9HsfMWYHr9suYdMrtssqiX", - "name": "Yaniv Tal" - }, - { - "description": "Developer Relations Engineer", - "id": "22MGz47c9WHtRiHuSEPkcG", - "name": "Kevin Jones" - }, - { - "description": "Description will go here", - "id": "JYTfEcdmdjiNzBg469gE83", - "name": "Pedro Diogo" - } - ] -} -``` - Example Query: Find all the articles written by employees that works at The Graph. ToolCall> ``` search_entity( - { - "query": "The Graph", - "traversal_filter": { - "relation_type": "Works at", - "direction": "From", - "traversal_filter": { - "relation_type": "Author", - "direction": "From" - } - } + { + "query": "The graph", + "traversal_filter": + { + "relation_type": "works at", + "traversal_filter": + { + "relation_type": "authors" + } } + } ) ``` ToolResult> ``` -{ - "entities": [ - { - "description": "A fresh look at what web3 is and what the missing pieces have been for making it a reality.", - "id": "XYo6aR3VqFQSEcf6AeTikW", - "name": "Knowledge graphs are web3" - }, - { - "description": "A new standard is here for structuring knowledge. GRC-20 will reshape how we make applications composable and redefine web3.", - "id": "5FkVvS4mTz6Ge7wHkAUMRk", - "name": "Introducing GRC-20: A knowledge graph standard for web3" - }, - { - "description": "How do you know what is true? Who do you trust? Everybody has a point of view, but no one is an authority. 
As humanity we need a way to aggregate our knowledge into something we can trust. We need a system.", - "id": "5WHP8BuoCdSiqtfy87SYWG", - "name": "Governing public knowledge" - } - ] -} +[ + 0: { description: "Founder & CEO of Geo. Cofounder of The Graph, Edge & Node, House of Web3. Building a vibrant decentralized future." + id: "9HsfMWYHr9suYdMrtssqiX" + name: "Yaniv Tal" + } + 1: { description: "Developer Relations Engineer" + id: "22MGz47c9WHtRiHuSEPkcG" + name: "Kevin Jones" + } + 2: { description: "Description will go here" + id: "JYTfEcdmdjiNzBg469gE83" + name: "Pedro Diogo" + } +] ``` diff --git a/mcp-server/resources/search_entity_using_ids_description.md b/mcp-server/resources/search_entity_using_ids_description.md index 740f54a..9ed0cb0 100644 --- a/mcp-server/resources/search_entity_using_ids_description.md +++ b/mcp-server/resources/search_entity_using_ids_description.md @@ -9,30 +9,24 @@ ToolCall> search_entity_using_ids({ "query": "The Graph", "traversal_filter": { - "relation_type": "U1uCAzXsRSTP4vFwo1JwJG", - "direction": "From" + "relation_type": "U1uCAzXsRSTP4vFwo1JwJG" } }) ``` ToolResult> ``` -{ - "entities": [ - { - "description": "Founder & CEO of Geo. Cofounder of The Graph, Edge & Node, House of Web3. Building a vibrant decentralized future.", - "id": "9HsfMWYHr9suYdMrtssqiX", - "name": "Yaniv Tal" - }, - { - "description": "Developer Relations Engineer", - "id": "22MGz47c9WHtRiHuSEPkcG", - "name": "Kevin Jones" - }, - { - "description": "Description will go here", - "id": "JYTfEcdmdjiNzBg469gE83", - "name": "Pedro Diogo" - } - ] -} +[ + 0: { description: "Founder & CEO of Geo. Cofounder of The Graph, Edge & Node, House of Web3. Building a vibrant decentralized future." 
+        id: "9HsfMWYHr9suYdMrtssqiX" + name: "Yaniv Tal" + } + 1: { description: "Developer Relations Engineer" + id: "22MGz47c9WHtRiHuSEPkcG" + name: "Kevin Jones" + } + 2: { description: "Description will go here" + id: "JYTfEcdmdjiNzBg469gE83" + name: "Pedro Diogo" + } +] ``` diff --git a/mcp-server/resources/search_space_description.md b/mcp-server/resources/search_space_description.md new file mode 100644 index 0000000..8f442ba --- /dev/null +++ b/mcp-server/resources/search_space_description.md @@ -0,0 +1,26 @@ +This request allows you to find a Space from its name or description. The spaces are where the attributes and relations are and may be useful to specify when querying entities and relations. + +ToolCall> +``` +search_space("San Francisco") +``` + +ToolResult> +``` +[ + [ + { + "attribute_name": "Description", + "attribute_value": "A vibrant city known for its iconic Golden Gate Bridge, steep rolling hills, historic cable cars, and a rich cultural tapestry including diverse neighborhoods like the Castro and the Mission District.", + "entity_id": "3qayfdjYyPv1dAYf8gPL5r" + }, + { + "attribute_name": "Name", + "attribute_value": "San Francisco", + "entity_id": "3qayfdjYyPv1dAYf8gPL5r" + } + ] +] +``` + +Eventually, space will be used to narrow research or help format results diff --git a/mcp-server/src/input_types.rs b/mcp-server/src/input_types.rs index 1464cb1..139ed06 100644 --- a/mcp-server/src/input_types.rs +++ b/mcp-server/src/input_types.rs @@ -7,9 +7,10 @@ pub struct SearchTraversalInputFilter { #[derive(Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)] pub struct TraversalFilter { - pub direction: RelationDirection, pub relation_type: String, #[serde(skip_serializing_if = "Option::is_none")] + pub direction: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub traversal_filter: Option>, } diff --git a/mcp-server/src/main.rs b/mcp-server/src/main.rs index f872489..64d64be 100644 --- a/mcp-server/src/main.rs +++ 
b/mcp-server/src/main.rs @@ -109,6 +109,7 @@ impl KnowledgeGraph { RawResource::new(uri, name.to_string()).no_annotation() } + // Tools that are available in the mcp server #[tool(description = include_str!("../resources/search_type_description.md"))] async fn search_types( &self, @@ -172,15 +173,7 @@ impl KnowledgeGraph { ) -> Result { tracing::info!("SearchTraversalFilter query: {:?}", search_traversal_filter); - let embedding = self - .embedding_model - .embed(vec![&search_traversal_filter.query], None) - .expect("Failed to get embedding") - .pop() - .expect("Embedding is empty") - .into_iter() - .map(|v| v as f64) - .collect::>(); + let embedding = self.create_embedding(vec![&search_traversal_filter.query]); let traversal_filters: Vec> = match search_traversal_filter.traversal_filter { @@ -188,8 +181,11 @@ impl KnowledgeGraph { join_all(traversal_filter_input.into_iter().map(|filter| async move { Ok(TraverseRelation::default() .direction(match filter.direction { - input_types::RelationDirection::From => RelationDirection::From, - input_types::RelationDirection::To => RelationDirection::To, + Some(direction) => match direction { + input_types::RelationDirection::From => RelationDirection::From, + input_types::RelationDirection::To => RelationDirection::To, + }, + None => RelationDirection::Both, }) .relation_type_id(prop_filter::value(filter.relation_type))) })) @@ -202,10 +198,7 @@ impl KnowledgeGraph { let results_search = traversal_filters .into_iter() .fold( - entity::search_from_restictions::>( - &self.neo4j, - embedding.clone(), - ), + entity::traversal_search::>(&self.neo4j, embedding.clone()), |query, result_traversal_filter: Result<_, McpError>| match result_traversal_filter { Ok(traversal_filter) => { @@ -258,15 +251,7 @@ impl KnowledgeGraph { ) -> Result { tracing::info!("SearchTraversalFilter query: {:?}", search_traversal_filter); - let embedding = self - .embedding_model - .embed(vec![&search_traversal_filter.query], None) - .expect("Failed to get 
embedding") - .pop() - .expect("Embedding is empty") - .into_iter() - .map(|v| v as f64) - .collect::>(); + let embedding = self.create_embedding(vec![&search_traversal_filter.query]); let start_filters = Instant::now(); @@ -287,8 +272,11 @@ impl KnowledgeGraph { Ok(TraverseRelation::default() .direction(match filter.direction { - input_types::RelationDirection::From => RelationDirection::From, - input_types::RelationDirection::To => RelationDirection::To, + Some(direction) => match direction { + input_types::RelationDirection::From => RelationDirection::From, + input_types::RelationDirection::To => RelationDirection::To, + }, + None => RelationDirection::Both, }) .relation_type_id(prop_filter::value_in(relation_ids))) })) @@ -303,10 +291,7 @@ impl KnowledgeGraph { let results_search = traversal_filters .into_iter() .fold( - entity::search_from_restictions::>( - &self.neo4j, - embedding.clone(), - ), + entity::traversal_search::>(&self.neo4j, embedding.clone()), |query, result_traversal_filter: Result<_, McpError>| match result_traversal_filter { Ok(traversal_filter) => { @@ -432,8 +417,8 @@ impl KnowledgeGraph { .into_iter() .map(|result| async move { json!({ - "relation_id": result.id, - "relation_type": self.get_name_of_id(result.relation_type.clone()).await.unwrap_or(result.relation_type.to_string()), + "relation_id": result.relation_type, + "relation_type": self.get_name_of_id(result.relation_type).await.unwrap_or("No relation type".to_string()), "id": if is_inbound {result.from.id.clone()} else {result.to.id.clone()}, "name": self.get_name_of_id(if is_inbound {result.from.id.clone()} else {result.to.id.clone()}).await.unwrap_or("No name".to_string()), }) @@ -502,21 +487,25 @@ impl KnowledgeGraph { )) } + #[inline(always)] + fn create_embedding(&self, name: Vec<&str>) -> Vec { + self.embedding_model + .embed(name, None) + .expect("Failed to get embedding") + .pop() + .expect("No embedding found") + .into_iter() + .map(|v| v as f64) + .collect::>() + } + async 
fn query_search( &self, query: String, limit: Option, filter: EntityFilter, ) -> Result, McpError> { - let embedding = self - .embedding_model - .embed(vec![&query], None) - .expect("Failed to get embedding") - .pop() - .expect("Embedding is empty") - .into_iter() - .map(|v| v as f64) - .collect::>(); + let embedding = self.create_embedding(vec![&query]); let limit = limit.unwrap_or(10); let semantic_search_triples = @@ -528,7 +517,7 @@ impl KnowledgeGraph { .map_err(|e| { tracing::error!("Error: {e:?}"); McpError::internal_error( - "search_types_failed", + "query_search_failed", Some(json!({ "error": e.to_string() })), ) })? @@ -537,7 +526,7 @@ impl KnowledgeGraph { .map_err(|e| { tracing::error!("Error changing to vec: {e:?}"); McpError::internal_error( - "search_types_failed", + "query_search_failed", Some(json!({ "error": e.to_string() })), ) })?; @@ -587,7 +576,9 @@ impl KnowledgeGraph { .ok_or_else(|| { McpError::internal_error("entity_name_not_found", Some(json!({ "id": id }))) })?; - Ok(entity.attributes.name.unwrap_or("No name".to_string())) + entity.attributes.name.ok_or_else(|| { + McpError::internal_error("entity_name_not_found", Some(json!({ "id": id }))) + }) } } @@ -618,53 +609,6 @@ impl ServerHandler for KnowledgeGraph { } Ok(self.get_info()) } - - //TODO: make prompt examples to use on data - async fn list_prompts( - &self, - _request: Option, - _: RequestContext, - ) -> Result { - Ok(ListPromptsResult { - next_cursor: None, - prompts: vec![Prompt::new( - "example_prompt", - Some("This is an example prompt that takes one required argument, message"), - Some(vec![PromptArgument { - name: "message".to_string(), - description: Some("A message to put in the prompt".to_string()), - required: Some(true), - }]), - )], - }) - } - - async fn get_prompt( - &self, - GetPromptRequestParam { name, arguments }: GetPromptRequestParam, - _: RequestContext, - ) -> Result { - match name.as_str() { - "example_prompt" => { - let message = arguments - .and_then(|json| 
json.get("message")?.as_str().map(|s| s.to_string())) - .ok_or_else(|| { - McpError::invalid_params("No message provided to example_prompt", None) - })?; - - let prompt = - format!("This is an example prompt with your message here: '{message}'"); - Ok(GetPromptResult { - description: None, - messages: vec![PromptMessage { - role: PromptMessageRole::User, - content: PromptMessageContent::text(prompt), - }], - }) - } - _ => Err(McpError::invalid_params("prompt not found", None)), - } - } } #[derive(Debug, Parser)]