refactor: add search engine module and update main.rs

This commit is contained in:
Alex Wellnitz 2024-02-11 16:52:39 +01:00
parent f070b45fac
commit 43bfb82a62
10 changed files with 98 additions and 695 deletions

View File

@ -1 +0,0 @@
pub mod tokenizer;

View File

@ -1,127 +0,0 @@
use std::collections::HashSet;
use regex::Regex;
use unicode_segmentation::UnicodeSegmentation;
pub struct Tokenizer {
text: String,
stopwords: HashSet<String>,
punctuation: HashSet<String>,
}
impl Tokenizer {
pub fn new(text: &str, stopwords: Vec<String>, punctuation: Option<Vec<String>>) -> Self {
Self {
text: text.to_owned(),
stopwords: stopwords
.iter()
.map(|s| s.to_owned())
.collect::<HashSet<String>>(),
punctuation: punctuation
.unwrap_or(
vec![
"!", "\"", "#", "$", "%", "&", "'", "(", ")", "*", "+", ",", ";", ".", "/",
":", ",", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|",
"}", "~", "-",
]
.iter()
.map(|s| s.to_string())
.collect::<Vec<String>>(),
)
.iter()
.map(|s| s.to_string())
.collect::<HashSet<String>>(),
}
}
// Split text into words
pub fn split_into_words(&self) -> Vec<String> {
self.text
.split_word_bounds()
.filter_map(|w| {
process_word(
w,
&get_special_char_regex(),
&self.stopwords,
&self.punctuation,
)
})
.collect::<Vec<String>>()
}
pub fn split_into_sentences(&self) -> Vec<String> {
let special_char_regex = get_special_char_regex();
get_sentence_space_regex()
.replace_all(&self.text, ".")
.unicode_sentences()
.map(|s| {
s.split_word_bounds()
.filter_map(|w| {
process_word(w, &special_char_regex, &self.stopwords, &self.punctuation)
})
.collect::<Vec<String>>()
.join(" ")
})
.collect::<Vec<String>>()
}
pub fn split_into_paragraphs(&self) -> Vec<String> {
get_newline_regex()
.split(&self.text)
.filter_map(|s| {
if s.trim().is_empty() {
return None;
}
Some(
s.unicode_sentences()
.map(|s| {
s.split_word_bounds()
.filter_map(|w| {
process_word(
w,
&get_special_char_regex(),
&self.stopwords,
&self.punctuation,
)
})
.collect::<Vec<String>>()
.join(" ")
})
.collect::<Vec<String>>()
.join(" "),
)
})
.collect::<Vec<String>>()
}
}
fn process_word(
w: &str,
special_char_regex: &Regex,
stopwords: &HashSet<String>,
punctuation: &HashSet<String>,
) -> Option<String> {
let word = special_char_regex.replace_all(w.trim(), "").to_lowercase();
if word.is_empty()
|| (word.graphemes(true).count() == 1) && punctuation.contains(&word)
|| stopwords.contains(&word)
{
return None;
}
Some(word)
}
fn get_special_char_regex() -> Regex {
Regex::new(r"('s|,|\.)").unwrap()
}
fn get_sentence_space_regex() -> Regex {
Regex::new(r"^([\.!?])[\n\t\r]").unwrap()
}
fn get_newline_regex() -> Regex {
Regex::new(r"(\r|\n|\r\n)").unwrap()
}

View File

@ -1,3 +1 @@
pub mod types;
pub mod search;
pub mod analyze;
pub mod search;

View File

@ -1,7 +1,13 @@
use rustysearch::search::Rustysearch;
use rustysearch::search::engine::SearchEngine;
fn main() {
println!("Hello, world!");
let search = Rustysearch::new("/tmp/rustysearch");
search.setup();
let mut engine = SearchEngine::new(1.5, 0.75);
engine.index("https://www.rust-lang.org/", "Rust is a systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety.");
engine.index("https://en.wikipedia.org/wiki/Rust_(programming_language)", "Rust is a multi-paradigm system programming language focused on safety, especially safe concurrency.");
let query = "Rust programming language threads";
let results = engine.search(query);
for (url, score) in results {
println!("{}: {}", url, score);
}
}

View File

@ -1,310 +0,0 @@
use std::{
cmp::min,
collections::{HashMap, HashSet},
fs,
io::{Read, Write},
path::Path,
};
use tempfile::NamedTempFile;
use serde_json::{json, Value};
use crate::{analyze::tokenizer::Tokenizer, types::Stats};
pub struct Rustysearch {
base_directory: String,
index_path: String,
docs_path: String,
stats_path: String,
}
impl Rustysearch {
/// **Sets up the object & the data directory**
///
/// Requires a ``base_directory`` parameter, which specifies the parent
/// directory the index/document/stats data will be kept in.
///
pub fn new(path: &str) -> Self {
Self {
base_directory: path.to_string(),
index_path: format!("{}/index", path),
docs_path: format!("{}/docs", path),
stats_path: format!("{}/stats.json", path),
}
}
/// **Handles the creation of the various data directories**
///
/// If the paths do not exist, it will create them. As a side effect, you
/// must have read/write access to the location you're trying to create
/// the data at.
///
pub fn setup(&self) {
// Create the base directory
if !Path::new(&self.base_directory).exists() {
fs::create_dir(&self.base_directory).expect("Unable to create base directory");
}
// Create the index directory
if !Path::new(&self.index_path).exists() {
fs::create_dir(&self.index_path).expect("Unable to create index directory");
}
// Create the docs directory
if !Path::new(&self.docs_path).exists() {
fs::create_dir(&self.docs_path).expect("Unable to create docs directory");
}
}
/// **Reads the index-wide stats**
///
/// If the stats do not exist, it makes returns data with the current
/// version of ``rustysearch`` & zero docs (used in scoring).
///
pub fn read_stats(&self) -> std::io::Result<Stats> {
let stats: Stats;
if !Path::new(&self.stats_path).exists() {
stats = Stats {
version: String::from("0.1.0"),
total_docs: 0,
};
} else {
// Read the stats file
let stats_json = fs::read_to_string(&self.stats_path).expect("Unable to read stats");
stats = serde_json::from_str(&stats_json).unwrap();
}
Ok(stats)
}
/// **Writes the index-wide stats**
///
/// Takes a ``new_stats`` parameter, which should be a dictionary of
/// stat data. Example stat data::
///
/// {
/// 'version': '1.0.0',
/// 'total_docs': 25,
/// }
///
pub fn write_stats(&self, new_stats: Stats) -> std::io::Result<()> {
// Write new_stats as json to stats_path
let new_stats_json = serde_json::to_string(&new_stats).unwrap();
fs::write(&self.stats_path, new_stats_json)?;
Ok(())
}
/// **Increments the total number of documents the index is aware of**
///
/// This is important for scoring reasons & is typically called as part
/// of the indexing process.
///
pub fn increment_total_docs(&self) {
let mut current_stats = self.read_stats().unwrap();
current_stats.total_docs += 1;
self.write_stats(current_stats).unwrap();
}
/// Returns the total number of documents the index is aware of
///
pub fn get_total_docs(&self) -> i32 {
let stats = self.read_stats().unwrap();
return stats.total_docs;
}
/// Given a string (``blob``) of text, this will return a Vector of tokens.
///
pub fn make_tokens(&self, blob: &str) -> Vec<String> {
let tokenizer = Tokenizer::new(blob, vec![], None);
let tokens = tokenizer.split_into_words();
return tokens;
}
/// **Converts a iterable of ``tokens`` into n-grams**
///
/// This assumes front grams (all grams made starting from the left side
/// of the token).
///
/// Optionally accepts a ``min_gram`` parameter, which takes an integer &
/// controls the minimum gram length. Default is ``3``.
///
/// Optionally accepts a ``max_gram`` parameter, which takes an integer &
/// controls the maximum gram length. Default is ``6``.
///
pub fn make_ngrams(
&self,
tokens: Vec<String>,
min_gram: usize,
max_gram: usize,
) -> HashMap<String, Vec<usize>> {
let mut terms: HashMap<String, Vec<usize>> = HashMap::new();
for (position, token) in tokens.iter().enumerate() {
for window_length in min_gram..min(max_gram + 1, token.len() + 1) {
// Assuming "front" grams.
let gram = &token[..window_length];
terms
.entry(gram.to_string())
.or_insert(Vec::new())
.push(position);
}
}
return terms;
}
/// Given a ``term``, hashes it & returns a string of the first N letters
///
/// Optionally accepts a ``length`` parameter, which takes an integer &
/// controls how much of the hash is returned. Default is ``6``.
///
/// This is usefully when writing files to the file system, as it helps
/// us keep from putting too many files in a given directory (~32K max
/// with the default).
///
pub fn hash_name(&self, term: &str, length: usize) -> String {
// Make sure it's ASCII.
let term = term.to_ascii_lowercase();
// We hash & slice the term to get a small-ish number of fields
// and good distribution between them.
let hash = md5::compute(&term);
let hashed = format!("{:x}", hash);
// Cut string after length characters
let hashed = &hashed[..length];
return hashed.to_string();
}
/// Given a ``term``, creates a segment filename based on the hash of the term.
///
/// Returns the full path to the segment.
///
pub fn make_segment_name(&self, term: &str) -> String {
let term = &self.hash_name(term, 6);
let index_file_name = format!("{}.index", term);
let segment_path = Path::new(&self.index_path).join(index_file_name);
let segment_path = segment_path.to_str().unwrap().to_string();
fs::write(&segment_path, "").expect("Unable to create segment file");
return segment_path;
}
/// Given a ``line`` from the segment file, this returns the term & its info.
///
/// The term info is stored as serialized JSON. The default separator
/// between the term & info is the ``\t`` character, which would never
/// appear in a term due to the way tokenization is done.
///
pub fn parse_record(&self, line: &str) -> (String, String) {
let mut parts = line.trim().split("\t");
let term = parts.next().unwrap().to_string();
let info = parts.next().unwrap().to_string();
(term, info)
}
/// Given a ``term`` and a dict of ``term_info``, creates a line for
/// writing to the segment file.
///
pub fn make_record(&self, term: &str, term_info: &Value) -> String {
format!("{}\t{}\n", term, json!(term_info).to_string())
}
/// Takes existing ``orig_info`` & ``new_info`` dicts & combines them
/// intelligently.
///
/// Used for updating term_info within the segments.
///
pub fn update_term_info(&self, orig_info: &mut Value, new_info: &Value) -> Value {
for (doc_id, positions) in new_info.as_object().unwrap().iter() {
if !orig_info.as_object().unwrap().contains_key(doc_id) {
orig_info[doc_id] = positions.clone();
} else {
let mut orig_positions: HashSet<_> = orig_info[doc_id]
.as_array()
.unwrap()
.iter()
.map(|v| v.as_str().unwrap().to_string())
.collect();
let new_positions: HashSet<_> = positions
.as_array()
.unwrap()
.iter()
.map(|v| v.as_str().unwrap().to_string())
.collect();
orig_positions.extend(new_positions);
orig_info[doc_id] = Value::Array(
orig_positions
.iter()
.map(|v| Value::String(v.clone()))
.collect(),
);
}
}
return orig_info.to_owned();
}
/// Writes out new index data to disk.
///
/// Takes a ``term`` string & ``term_info`` dict. It will
/// rewrite the segment in alphabetical order, adding in the data
/// where appropriate.
///
/// Optionally takes an ``update`` parameter, which is a boolean &
/// determines whether the provided ``term_info`` should overwrite or
/// update the data in the segment. Default is ``False`` (overwrite).
///
pub fn save_segment(&self, term: &str, term_info: &Value, update: bool) -> bool {
let seg_name = &self.make_segment_name(term);
let mut new_seg_file = NamedTempFile::new().unwrap();
let mut written = false;
if !Path::new(&seg_name).exists() {
fs::write(&seg_name, "").unwrap();
}
let mut seg_file = fs::OpenOptions::new().read(true).open(&seg_name).unwrap();
let mut buf = String::new();
seg_file.read_to_string(&mut buf).unwrap();
for line in buf.lines() {
let (seg_term, seg_term_info) = self.parse_record(line);
if !written && seg_term > term.to_string() {
let new_line = self.make_record(term, term_info);
new_seg_file.write_all(new_line.as_bytes()).unwrap();
written = true;
} else if seg_term == term {
if update {
let new_info = self.update_term_info(&mut json!(seg_term_info), term_info);
let new_line = self.make_record(term, &new_info);
new_seg_file.write_all(new_line.as_bytes()).unwrap();
} else {
let line = self.make_record(term, term_info);
new_seg_file.write_all(line.as_bytes()).unwrap();
}
written = true;
}
new_seg_file.write_all(line.as_bytes()).unwrap();
}
if !written {
let line = self.make_record(term, term_info);
new_seg_file.write_all(line.as_bytes()).unwrap();
}
fs::rename(&new_seg_file.path(), &seg_name).unwrap();
new_seg_file.flush().unwrap();
// new_seg_file.close().unwrap();
return true;
}
}

86
src/search/engine.rs Normal file
View File

@ -0,0 +1,86 @@
use std::collections::HashMap;
use std::f64;
pub fn update_url_scores(old: &mut HashMap<String, f64>, new: &HashMap<String, f64>) {
for (url, score) in new {
old.entry(url.to_string()).and_modify(|e| *e += score).or_insert(*score);
}
}
pub fn normalize_string(input_string: &str) -> String {
let string_without_punc: String = input_string.chars().filter(|&c| !c.is_ascii_punctuation()).collect();
let string_without_double_spaces: String = string_without_punc.split_whitespace().collect::<Vec<&str>>().join(" ");
string_without_double_spaces.to_lowercase()
}
pub struct SearchEngine {
index: HashMap<String, HashMap<String, i32>>,
documents: HashMap<String, String>,
k1: f64,
b: f64,
}
impl SearchEngine {
pub fn new(k1: f64, b: f64) -> SearchEngine {
SearchEngine {
index: HashMap::new(),
documents: HashMap::new(),
k1,
b,
}
}
pub fn posts(&self) -> Vec<String> {
self.documents.keys().cloned().collect()
}
pub fn number_of_documents(&self) -> usize {
self.documents.len()
}
pub fn avdl(&self) -> f64 {
let total_length: usize = self.documents.values().map(|d| d.len()).sum();
total_length as f64 / self.documents.len() as f64
}
pub fn idf(&self, kw: &str) -> f64 {
let n = self.number_of_documents() as f64;
let n_kw = self.get_urls(kw).len() as f64;
((n - n_kw + 0.5) / (n_kw + 0.5) + 1.0).ln()
}
pub fn bm25(&self, kw: &str) -> HashMap<String, f64> {
let mut result = HashMap::new();
let idf_score = self.idf(kw);
let avdl = self.avdl();
for (url, freq) in self.get_urls(kw) {
let numerator = freq as f64 * (self.k1 + 1.0);
let denominator = freq as f64 + self.k1 * (1.0 - self.b + self.b * self.documents.get(&url).unwrap().len() as f64 / avdl);
result.insert(url.to_string(), idf_score * numerator / denominator);
}
result
}
pub fn search(&mut self, query: &str) -> HashMap<String, f64> {
let keywords = normalize_string(query).split_whitespace().map(|s| s.to_string()).collect::<Vec<String>>();
let mut url_scores: HashMap<String, f64> = HashMap::new();
for kw in keywords {
let kw_urls_score = self.bm25(&kw);
update_url_scores(&mut url_scores, &kw_urls_score);
}
url_scores
}
pub fn index(&mut self, url: &str, content: &str) {
self.documents.insert(url.to_string(), content.to_string());
let words = normalize_string(content).split_whitespace().map(|s| s.to_string()).collect::<Vec<String>>();
for word in words {
*self.index.entry(word).or_insert(HashMap::new()).entry(url.to_string()).or_insert(0) += 1;
}
}
pub fn get_urls(&self, keyword: &str) -> HashMap<String, i32> {
let keyword = normalize_string(keyword);
self.index.get(&keyword).cloned().unwrap_or(HashMap::new())
}
}

1
src/search/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod engine;

View File

@ -1,7 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
pub struct Stats{
pub version: String,
pub total_docs: i32,
}

View File

@ -1,190 +0,0 @@
#[cfg(test)]
mod tests {
use rustysearch::{search::Rustysearch, types::Stats};
use serde_json::json;
#[test]
fn test_write_new_stats() {
let stats = Stats {
version: String::from("0.1.0"),
total_docs: 0,
};
assert_eq!(stats.version, "0.1.0");
assert_eq!(stats.total_docs, 0);
let tmp_path = "/tmp/rustysearch_writenewstats";
let search = Rustysearch::new(&tmp_path);
search.setup();
search.write_stats(stats).unwrap();
}
#[test]
fn test_read_stats() {
let tmp_path = "/tmp/rustysearch_readstats";
let search = Rustysearch::new(&tmp_path);
search.setup();
clean_stats(tmp_path);
let stats = search.read_stats().unwrap();
assert_eq!(stats.version, "0.1.0");
assert_eq!(stats.total_docs, 0);
}
#[test]
fn test_increment_total_docs() {
let tmp_path = "/tmp/rustysearch_incrementtotaldocs";
let search = Rustysearch::new(&tmp_path);
search.setup();
clean_stats(tmp_path);
let stats = search.read_stats().unwrap();
assert_eq!(stats.total_docs, 0);
search.increment_total_docs();
let stats = search.read_stats().unwrap();
assert_eq!(stats.total_docs, 1);
}
#[test]
fn test_get_total_docs() {
let tmp_path = "/tmp/rustysearch_gettotaldocs";
let search = Rustysearch::new(&tmp_path);
search.setup();
clean_stats(tmp_path);
let stats = search.read_stats().unwrap();
assert_eq!(stats.total_docs, 0);
search.increment_total_docs();
let stats = search.read_stats().unwrap();
assert_eq!(stats.total_docs, 1);
let total_docs = search.get_total_docs();
assert_eq!(total_docs, 1);
}
#[test]
fn test_make_ngrams() {
let search = Rustysearch::new("/tmp/rustysearch_makengrams");
search.setup();
let tokens = vec!["hello".to_string(), "world".to_string()];
let terms = search.make_ngrams(tokens, 3, 6);
assert_eq!(terms["hel"].len(), 1);
}
#[test]
fn test_hash_name() {
let search = Rustysearch::new("/tmp/rustysearch_hashname");
search.setup();
let hash = search.hash_name("hello", 6);
assert_eq!(hash, "5d4140");
}
#[test]
fn test_make_segment_name() {
let search = Rustysearch::new("/tmp/rustysearch_makesegmentname");
search.setup();
let segment_name = search.make_segment_name("hello");
assert_eq!(
segment_name,
"/tmp/rustysearch_makesegmentname/index/5d4140.index"
);
}
#[test]
fn test_parse_record() {
let search = Rustysearch::new("/tmp/rustysearch_parserecord");
search.setup();
let line = "my_term\t{\"frequency\": 100}";
let (term, info) = search.parse_record(line);
assert_eq!(term, "my_term");
assert_eq!(info, "{\"frequency\": 100}");
}
#[test]
fn test_make_tokens() {
let search = Rustysearch::new("/tmp/rustysearch");
let tokens = search.make_tokens("Hello, world!");
assert_eq!(tokens, vec!["hello", "world"]);
}
#[test]
fn test_make_record() {
let search = Rustysearch::new("/tmp/rustysearch");
let term = "hello world";
let term_info = json!({
"frequency": 100,
"idf": 1.5,
});
let record = search.make_record(term, &term_info);
assert_eq!(record, "hello world\t{\"frequency\":100,\"idf\":1.5}\n");
}
#[test]
fn test_update_term_info() {
let mut orig_info = json!({
"doc1": ["1", "2"],
"doc2": ["3", "4"]
});
let new_info = json!({
"doc3": ["1", "2"]
});
let expected_result = json!({
"doc1": ["1", "2"],
"doc2": ["3", "4"],
"doc3": ["1", "2"]
});
let search = Rustysearch::new("/tmp/rustysearch");
let result = search.update_term_info(&mut orig_info, &new_info);
assert_eq!(result, expected_result);
}
#[test]
fn test_save_segment() {
let search = Rustysearch::new("/tmp/rustysearch_save_segment");
search.setup();
let term = "rust";
let term_info = json!({"doc1": ["1", "5"], "doc2": ["2", "6"]});
// Test saving a new segment
let result = search.save_segment(term, &term_info, false);
assert_eq!(result, true);
// Test updating an existing segment
let new_term_info = json!({"doc1": ["1", "5", "10"], "doc3": ["3", "7"]});
let result = search.save_segment(term, &new_term_info, true);
assert_eq!(result, true);
// Test overwriting an existing segment
let result = search.save_segment(term, &term_info, false);
assert_eq!(result, true);
}
// Helper function to clean up the stats file
fn clean_stats(tmp_path: &str) {
let search = Rustysearch::new(tmp_path);
search.setup();
let new_stats = Stats {
version: String::from("0.1.0"),
total_docs: 0,
};
search.write_stats(new_stats).unwrap();
}
}

View File

@ -1,53 +0,0 @@
#[cfg(test)]
mod tests {
use rustysearch::analyze::tokenizer::Tokenizer;
#[test]
fn test_split_into_words() {
let text = "The quick brown fox jumps over the lazy dog.";
let stopwords = vec!["the".to_string(), "over".to_string()];
let tokenizer = Tokenizer::new(text, stopwords, None);
let words = tokenizer.split_into_words();
assert_eq!(
words,
vec![
"quick".to_string(),
"brown".to_string(),
"fox".to_string(),
"jumps".to_string(),
"lazy".to_string(),
"dog".to_string(),
]
);
}
#[test]
fn test_split_into_sentences() {
let text = "The quick brown fox jumps over the lazy dog. The end.";
let stopwords = vec!["the".to_string(), "over".to_string()];
let tokenizer = Tokenizer::new(text, stopwords, None);
let sentences = tokenizer.split_into_sentences();
assert_eq!(
sentences,
vec![
"quick brown fox jumps lazy dog".to_string(),
"end".to_string(),
]
);
}
#[test]
fn test_split_into_paragraphs() {
let text = "The quick brown fox jumps over the lazy dog.\n\nThe end.";
let stopwords = vec!["the".to_string(), "over".to_string()];
let tokenizer = Tokenizer::new(text, stopwords, None);
let paragraphs = tokenizer.split_into_paragraphs();
assert_eq!(
paragraphs,
vec![
"quick brown fox jumps lazy dog".to_string(),
"end".to_string(),
]
);
}
}