Merge pull request #1 from alexohneander/refactor
refactor: add search engine module and update main.rs
This commit is contained in:
commit
77b684ac42
@ -1 +0,0 @@
|
||||
pub mod tokenizer;
|
@ -1,127 +0,0 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use regex::Regex;
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
|
||||
pub struct Tokenizer {
|
||||
text: String,
|
||||
stopwords: HashSet<String>,
|
||||
punctuation: HashSet<String>,
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
pub fn new(text: &str, stopwords: Vec<String>, punctuation: Option<Vec<String>>) -> Self {
|
||||
Self {
|
||||
text: text.to_owned(),
|
||||
stopwords: stopwords
|
||||
.iter()
|
||||
.map(|s| s.to_owned())
|
||||
.collect::<HashSet<String>>(),
|
||||
punctuation: punctuation
|
||||
.unwrap_or(
|
||||
vec![
|
||||
"!", "\"", "#", "$", "%", "&", "'", "(", ")", "*", "+", ",", ";", ".", "/",
|
||||
":", ",", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|",
|
||||
"}", "~", "-",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<String>>(),
|
||||
)
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect::<HashSet<String>>(),
|
||||
}
|
||||
}
|
||||
|
||||
// Split text into words
|
||||
pub fn split_into_words(&self) -> Vec<String> {
|
||||
self.text
|
||||
.split_word_bounds()
|
||||
.filter_map(|w| {
|
||||
process_word(
|
||||
w,
|
||||
&get_special_char_regex(),
|
||||
&self.stopwords,
|
||||
&self.punctuation,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
}
|
||||
|
||||
pub fn split_into_sentences(&self) -> Vec<String> {
|
||||
let special_char_regex = get_special_char_regex();
|
||||
get_sentence_space_regex()
|
||||
.replace_all(&self.text, ".")
|
||||
.unicode_sentences()
|
||||
.map(|s| {
|
||||
s.split_word_bounds()
|
||||
.filter_map(|w| {
|
||||
process_word(w, &special_char_regex, &self.stopwords, &self.punctuation)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
}
|
||||
|
||||
pub fn split_into_paragraphs(&self) -> Vec<String> {
|
||||
get_newline_regex()
|
||||
.split(&self.text)
|
||||
.filter_map(|s| {
|
||||
if s.trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
s.unicode_sentences()
|
||||
.map(|s| {
|
||||
s.split_word_bounds()
|
||||
.filter_map(|w| {
|
||||
process_word(
|
||||
w,
|
||||
&get_special_char_regex(),
|
||||
&self.stopwords,
|
||||
&self.punctuation,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join(" "),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
}
|
||||
}
|
||||
|
||||
fn process_word(
|
||||
w: &str,
|
||||
special_char_regex: &Regex,
|
||||
stopwords: &HashSet<String>,
|
||||
punctuation: &HashSet<String>,
|
||||
) -> Option<String> {
|
||||
let word = special_char_regex.replace_all(w.trim(), "").to_lowercase();
|
||||
|
||||
if word.is_empty()
|
||||
|| (word.graphemes(true).count() == 1) && punctuation.contains(&word)
|
||||
|| stopwords.contains(&word)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(word)
|
||||
}
|
||||
|
||||
fn get_special_char_regex() -> Regex {
|
||||
Regex::new(r"('s|,|\.)").unwrap()
|
||||
}
|
||||
|
||||
fn get_sentence_space_regex() -> Regex {
|
||||
Regex::new(r"^([\.!?])[\n\t\r]").unwrap()
|
||||
}
|
||||
|
||||
fn get_newline_regex() -> Regex {
|
||||
Regex::new(r"(\r|\n|\r\n)").unwrap()
|
||||
}
|
@ -1,3 +1 @@
|
||||
pub mod types;
|
||||
pub mod search;
|
||||
pub mod analyze;
|
||||
pub mod search;
|
14
src/main.rs
14
src/main.rs
@ -1,7 +1,13 @@
|
||||
use rustysearch::search::Rustysearch;
|
||||
use rustysearch::search::engine::SearchEngine;
|
||||
|
||||
fn main() {
|
||||
println!("Hello, world!");
|
||||
let search = Rustysearch::new("/tmp/rustysearch");
|
||||
search.setup();
|
||||
let mut engine = SearchEngine::new(1.5, 0.75);
|
||||
engine.index("https://www.rust-lang.org/", "Rust is a systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety.");
|
||||
engine.index("https://en.wikipedia.org/wiki/Rust_(programming_language)", "Rust is a multi-paradigm system programming language focused on safety, especially safe concurrency.");
|
||||
|
||||
let query = "Rust programming language threads";
|
||||
let results = engine.search(query);
|
||||
for (url, score) in results {
|
||||
println!("{}: {}", url, score);
|
||||
}
|
||||
}
|
310
src/search.rs
310
src/search.rs
@ -1,310 +0,0 @@
|
||||
use std::{
|
||||
cmp::min,
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
io::{Read, Write},
|
||||
path::Path,
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::{analyze::tokenizer::Tokenizer, types::Stats};
|
||||
|
||||
pub struct Rustysearch {
|
||||
base_directory: String,
|
||||
index_path: String,
|
||||
docs_path: String,
|
||||
stats_path: String,
|
||||
}
|
||||
|
||||
impl Rustysearch {
|
||||
/// **Sets up the object & the data directory**
|
||||
///
|
||||
/// Requires a ``base_directory`` parameter, which specifies the parent
|
||||
/// directory the index/document/stats data will be kept in.
|
||||
///
|
||||
pub fn new(path: &str) -> Self {
|
||||
Self {
|
||||
base_directory: path.to_string(),
|
||||
index_path: format!("{}/index", path),
|
||||
docs_path: format!("{}/docs", path),
|
||||
stats_path: format!("{}/stats.json", path),
|
||||
}
|
||||
}
|
||||
|
||||
/// **Handles the creation of the various data directories**
|
||||
///
|
||||
/// If the paths do not exist, it will create them. As a side effect, you
|
||||
/// must have read/write access to the location you're trying to create
|
||||
/// the data at.
|
||||
///
|
||||
pub fn setup(&self) {
|
||||
// Create the base directory
|
||||
if !Path::new(&self.base_directory).exists() {
|
||||
fs::create_dir(&self.base_directory).expect("Unable to create base directory");
|
||||
}
|
||||
// Create the index directory
|
||||
if !Path::new(&self.index_path).exists() {
|
||||
fs::create_dir(&self.index_path).expect("Unable to create index directory");
|
||||
}
|
||||
// Create the docs directory
|
||||
if !Path::new(&self.docs_path).exists() {
|
||||
fs::create_dir(&self.docs_path).expect("Unable to create docs directory");
|
||||
}
|
||||
}
|
||||
|
||||
/// **Reads the index-wide stats**
|
||||
///
|
||||
/// If the stats do not exist, it makes returns data with the current
|
||||
/// version of ``rustysearch`` & zero docs (used in scoring).
|
||||
///
|
||||
pub fn read_stats(&self) -> std::io::Result<Stats> {
|
||||
let stats: Stats;
|
||||
|
||||
if !Path::new(&self.stats_path).exists() {
|
||||
stats = Stats {
|
||||
version: String::from("0.1.0"),
|
||||
total_docs: 0,
|
||||
};
|
||||
} else {
|
||||
// Read the stats file
|
||||
let stats_json = fs::read_to_string(&self.stats_path).expect("Unable to read stats");
|
||||
stats = serde_json::from_str(&stats_json).unwrap();
|
||||
}
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
/// **Writes the index-wide stats**
|
||||
///
|
||||
/// Takes a ``new_stats`` parameter, which should be a dictionary of
|
||||
/// stat data. Example stat data::
|
||||
///
|
||||
/// {
|
||||
/// 'version': '1.0.0',
|
||||
/// 'total_docs': 25,
|
||||
/// }
|
||||
///
|
||||
pub fn write_stats(&self, new_stats: Stats) -> std::io::Result<()> {
|
||||
// Write new_stats as json to stats_path
|
||||
let new_stats_json = serde_json::to_string(&new_stats).unwrap();
|
||||
fs::write(&self.stats_path, new_stats_json)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// **Increments the total number of documents the index is aware of**
|
||||
///
|
||||
/// This is important for scoring reasons & is typically called as part
|
||||
/// of the indexing process.
|
||||
///
|
||||
pub fn increment_total_docs(&self) {
|
||||
let mut current_stats = self.read_stats().unwrap();
|
||||
current_stats.total_docs += 1;
|
||||
self.write_stats(current_stats).unwrap();
|
||||
}
|
||||
|
||||
/// Returns the total number of documents the index is aware of
|
||||
///
|
||||
pub fn get_total_docs(&self) -> i32 {
|
||||
let stats = self.read_stats().unwrap();
|
||||
return stats.total_docs;
|
||||
}
|
||||
|
||||
/// Given a string (``blob``) of text, this will return a Vector of tokens.
|
||||
///
|
||||
pub fn make_tokens(&self, blob: &str) -> Vec<String> {
|
||||
let tokenizer = Tokenizer::new(blob, vec![], None);
|
||||
let tokens = tokenizer.split_into_words();
|
||||
return tokens;
|
||||
}
|
||||
|
||||
/// **Converts a iterable of ``tokens`` into n-grams**
|
||||
///
|
||||
/// This assumes front grams (all grams made starting from the left side
|
||||
/// of the token).
|
||||
///
|
||||
/// Optionally accepts a ``min_gram`` parameter, which takes an integer &
|
||||
/// controls the minimum gram length. Default is ``3``.
|
||||
///
|
||||
/// Optionally accepts a ``max_gram`` parameter, which takes an integer &
|
||||
/// controls the maximum gram length. Default is ``6``.
|
||||
///
|
||||
pub fn make_ngrams(
|
||||
&self,
|
||||
tokens: Vec<String>,
|
||||
min_gram: usize,
|
||||
max_gram: usize,
|
||||
) -> HashMap<String, Vec<usize>> {
|
||||
let mut terms: HashMap<String, Vec<usize>> = HashMap::new();
|
||||
|
||||
for (position, token) in tokens.iter().enumerate() {
|
||||
for window_length in min_gram..min(max_gram + 1, token.len() + 1) {
|
||||
// Assuming "front" grams.
|
||||
let gram = &token[..window_length];
|
||||
terms
|
||||
.entry(gram.to_string())
|
||||
.or_insert(Vec::new())
|
||||
.push(position);
|
||||
}
|
||||
}
|
||||
|
||||
return terms;
|
||||
}
|
||||
|
||||
/// Given a ``term``, hashes it & returns a string of the first N letters
|
||||
///
|
||||
/// Optionally accepts a ``length`` parameter, which takes an integer &
|
||||
/// controls how much of the hash is returned. Default is ``6``.
|
||||
///
|
||||
/// This is usefully when writing files to the file system, as it helps
|
||||
/// us keep from putting too many files in a given directory (~32K max
|
||||
/// with the default).
|
||||
///
|
||||
pub fn hash_name(&self, term: &str, length: usize) -> String {
|
||||
// Make sure it's ASCII.
|
||||
let term = term.to_ascii_lowercase();
|
||||
|
||||
// We hash & slice the term to get a small-ish number of fields
|
||||
// and good distribution between them.
|
||||
let hash = md5::compute(&term);
|
||||
let hashed = format!("{:x}", hash);
|
||||
|
||||
// Cut string after length characters
|
||||
let hashed = &hashed[..length];
|
||||
|
||||
return hashed.to_string();
|
||||
}
|
||||
|
||||
/// Given a ``term``, creates a segment filename based on the hash of the term.
|
||||
///
|
||||
/// Returns the full path to the segment.
|
||||
///
|
||||
pub fn make_segment_name(&self, term: &str) -> String {
|
||||
let term = &self.hash_name(term, 6);
|
||||
|
||||
let index_file_name = format!("{}.index", term);
|
||||
let segment_path = Path::new(&self.index_path).join(index_file_name);
|
||||
let segment_path = segment_path.to_str().unwrap().to_string();
|
||||
|
||||
fs::write(&segment_path, "").expect("Unable to create segment file");
|
||||
|
||||
return segment_path;
|
||||
}
|
||||
|
||||
/// Given a ``line`` from the segment file, this returns the term & its info.
|
||||
///
|
||||
/// The term info is stored as serialized JSON. The default separator
|
||||
/// between the term & info is the ``\t`` character, which would never
|
||||
/// appear in a term due to the way tokenization is done.
|
||||
///
|
||||
pub fn parse_record(&self, line: &str) -> (String, String) {
|
||||
let mut parts = line.trim().split("\t");
|
||||
let term = parts.next().unwrap().to_string();
|
||||
let info = parts.next().unwrap().to_string();
|
||||
(term, info)
|
||||
}
|
||||
|
||||
/// Given a ``term`` and a dict of ``term_info``, creates a line for
|
||||
/// writing to the segment file.
|
||||
///
|
||||
pub fn make_record(&self, term: &str, term_info: &Value) -> String {
|
||||
format!("{}\t{}\n", term, json!(term_info).to_string())
|
||||
}
|
||||
|
||||
/// Takes existing ``orig_info`` & ``new_info`` dicts & combines them
|
||||
/// intelligently.
|
||||
///
|
||||
/// Used for updating term_info within the segments.
|
||||
///
|
||||
pub fn update_term_info(&self, orig_info: &mut Value, new_info: &Value) -> Value {
|
||||
for (doc_id, positions) in new_info.as_object().unwrap().iter() {
|
||||
if !orig_info.as_object().unwrap().contains_key(doc_id) {
|
||||
orig_info[doc_id] = positions.clone();
|
||||
} else {
|
||||
let mut orig_positions: HashSet<_> = orig_info[doc_id]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|v| v.as_str().unwrap().to_string())
|
||||
.collect();
|
||||
let new_positions: HashSet<_> = positions
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|v| v.as_str().unwrap().to_string())
|
||||
.collect();
|
||||
|
||||
orig_positions.extend(new_positions);
|
||||
|
||||
orig_info[doc_id] = Value::Array(
|
||||
orig_positions
|
||||
.iter()
|
||||
.map(|v| Value::String(v.clone()))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return orig_info.to_owned();
|
||||
}
|
||||
|
||||
/// Writes out new index data to disk.
|
||||
///
|
||||
/// Takes a ``term`` string & ``term_info`` dict. It will
|
||||
/// rewrite the segment in alphabetical order, adding in the data
|
||||
/// where appropriate.
|
||||
///
|
||||
/// Optionally takes an ``update`` parameter, which is a boolean &
|
||||
/// determines whether the provided ``term_info`` should overwrite or
|
||||
/// update the data in the segment. Default is ``False`` (overwrite).
|
||||
///
|
||||
pub fn save_segment(&self, term: &str, term_info: &Value, update: bool) -> bool {
|
||||
let seg_name = &self.make_segment_name(term);
|
||||
let mut new_seg_file = NamedTempFile::new().unwrap();
|
||||
let mut written = false;
|
||||
|
||||
if !Path::new(&seg_name).exists() {
|
||||
fs::write(&seg_name, "").unwrap();
|
||||
}
|
||||
|
||||
let mut seg_file = fs::OpenOptions::new().read(true).open(&seg_name).unwrap();
|
||||
let mut buf = String::new();
|
||||
seg_file.read_to_string(&mut buf).unwrap();
|
||||
|
||||
for line in buf.lines() {
|
||||
let (seg_term, seg_term_info) = self.parse_record(line);
|
||||
|
||||
if !written && seg_term > term.to_string() {
|
||||
let new_line = self.make_record(term, term_info);
|
||||
new_seg_file.write_all(new_line.as_bytes()).unwrap();
|
||||
written = true;
|
||||
} else if seg_term == term {
|
||||
if update {
|
||||
let new_info = self.update_term_info(&mut json!(seg_term_info), term_info);
|
||||
let new_line = self.make_record(term, &new_info);
|
||||
new_seg_file.write_all(new_line.as_bytes()).unwrap();
|
||||
} else {
|
||||
let line = self.make_record(term, term_info);
|
||||
new_seg_file.write_all(line.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
written = true;
|
||||
}
|
||||
|
||||
new_seg_file.write_all(line.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
if !written {
|
||||
let line = self.make_record(term, term_info);
|
||||
new_seg_file.write_all(line.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
fs::rename(&new_seg_file.path(), &seg_name).unwrap();
|
||||
|
||||
new_seg_file.flush().unwrap();
|
||||
// new_seg_file.close().unwrap();
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
86
src/search/engine.rs
Normal file
86
src/search/engine.rs
Normal file
@ -0,0 +1,86 @@
|
||||
use std::collections::HashMap;
|
||||
use std::f64;
|
||||
|
||||
pub fn update_url_scores(old: &mut HashMap<String, f64>, new: &HashMap<String, f64>) {
|
||||
for (url, score) in new {
|
||||
old.entry(url.to_string()).and_modify(|e| *e += score).or_insert(*score);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn normalize_string(input_string: &str) -> String {
|
||||
let string_without_punc: String = input_string.chars().filter(|&c| !c.is_ascii_punctuation()).collect();
|
||||
let string_without_double_spaces: String = string_without_punc.split_whitespace().collect::<Vec<&str>>().join(" ");
|
||||
string_without_double_spaces.to_lowercase()
|
||||
}
|
||||
|
||||
pub struct SearchEngine {
|
||||
index: HashMap<String, HashMap<String, i32>>,
|
||||
documents: HashMap<String, String>,
|
||||
k1: f64,
|
||||
b: f64,
|
||||
}
|
||||
|
||||
impl SearchEngine {
|
||||
pub fn new(k1: f64, b: f64) -> SearchEngine {
|
||||
SearchEngine {
|
||||
index: HashMap::new(),
|
||||
documents: HashMap::new(),
|
||||
k1,
|
||||
b,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn posts(&self) -> Vec<String> {
|
||||
self.documents.keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn number_of_documents(&self) -> usize {
|
||||
self.documents.len()
|
||||
}
|
||||
|
||||
pub fn avdl(&self) -> f64 {
|
||||
let total_length: usize = self.documents.values().map(|d| d.len()).sum();
|
||||
total_length as f64 / self.documents.len() as f64
|
||||
}
|
||||
|
||||
pub fn idf(&self, kw: &str) -> f64 {
|
||||
let n = self.number_of_documents() as f64;
|
||||
let n_kw = self.get_urls(kw).len() as f64;
|
||||
((n - n_kw + 0.5) / (n_kw + 0.5) + 1.0).ln()
|
||||
}
|
||||
|
||||
pub fn bm25(&self, kw: &str) -> HashMap<String, f64> {
|
||||
let mut result = HashMap::new();
|
||||
let idf_score = self.idf(kw);
|
||||
let avdl = self.avdl();
|
||||
for (url, freq) in self.get_urls(kw) {
|
||||
let numerator = freq as f64 * (self.k1 + 1.0);
|
||||
let denominator = freq as f64 + self.k1 * (1.0 - self.b + self.b * self.documents.get(&url).unwrap().len() as f64 / avdl);
|
||||
result.insert(url.to_string(), idf_score * numerator / denominator);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn search(&mut self, query: &str) -> HashMap<String, f64> {
|
||||
let keywords = normalize_string(query).split_whitespace().map(|s| s.to_string()).collect::<Vec<String>>();
|
||||
let mut url_scores: HashMap<String, f64> = HashMap::new();
|
||||
for kw in keywords {
|
||||
let kw_urls_score = self.bm25(&kw);
|
||||
update_url_scores(&mut url_scores, &kw_urls_score);
|
||||
}
|
||||
url_scores
|
||||
}
|
||||
|
||||
pub fn index(&mut self, url: &str, content: &str) {
|
||||
self.documents.insert(url.to_string(), content.to_string());
|
||||
let words = normalize_string(content).split_whitespace().map(|s| s.to_string()).collect::<Vec<String>>();
|
||||
for word in words {
|
||||
*self.index.entry(word).or_insert(HashMap::new()).entry(url.to_string()).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_urls(&self, keyword: &str) -> HashMap<String, i32> {
|
||||
let keyword = normalize_string(keyword);
|
||||
self.index.get(&keyword).cloned().unwrap_or(HashMap::new())
|
||||
}
|
||||
}
|
1
src/search/mod.rs
Normal file
1
src/search/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod engine;
|
@ -1,7 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Stats{
|
||||
pub version: String,
|
||||
pub total_docs: i32,
|
||||
}
|
@ -1,190 +0,0 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rustysearch::{search::Rustysearch, types::Stats};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_write_new_stats() {
|
||||
let stats = Stats {
|
||||
version: String::from("0.1.0"),
|
||||
total_docs: 0,
|
||||
};
|
||||
|
||||
assert_eq!(stats.version, "0.1.0");
|
||||
assert_eq!(stats.total_docs, 0);
|
||||
|
||||
let tmp_path = "/tmp/rustysearch_writenewstats";
|
||||
let search = Rustysearch::new(&tmp_path);
|
||||
search.setup();
|
||||
|
||||
search.write_stats(stats).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_stats() {
|
||||
let tmp_path = "/tmp/rustysearch_readstats";
|
||||
let search = Rustysearch::new(&tmp_path);
|
||||
search.setup();
|
||||
|
||||
clean_stats(tmp_path);
|
||||
|
||||
let stats = search.read_stats().unwrap();
|
||||
assert_eq!(stats.version, "0.1.0");
|
||||
assert_eq!(stats.total_docs, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_increment_total_docs() {
|
||||
let tmp_path = "/tmp/rustysearch_incrementtotaldocs";
|
||||
let search = Rustysearch::new(&tmp_path);
|
||||
search.setup();
|
||||
|
||||
clean_stats(tmp_path);
|
||||
|
||||
let stats = search.read_stats().unwrap();
|
||||
assert_eq!(stats.total_docs, 0);
|
||||
|
||||
search.increment_total_docs();
|
||||
let stats = search.read_stats().unwrap();
|
||||
assert_eq!(stats.total_docs, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_total_docs() {
|
||||
let tmp_path = "/tmp/rustysearch_gettotaldocs";
|
||||
let search = Rustysearch::new(&tmp_path);
|
||||
search.setup();
|
||||
|
||||
clean_stats(tmp_path);
|
||||
|
||||
let stats = search.read_stats().unwrap();
|
||||
assert_eq!(stats.total_docs, 0);
|
||||
|
||||
search.increment_total_docs();
|
||||
let stats = search.read_stats().unwrap();
|
||||
assert_eq!(stats.total_docs, 1);
|
||||
|
||||
let total_docs = search.get_total_docs();
|
||||
assert_eq!(total_docs, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_make_ngrams() {
|
||||
let search = Rustysearch::new("/tmp/rustysearch_makengrams");
|
||||
search.setup();
|
||||
|
||||
let tokens = vec!["hello".to_string(), "world".to_string()];
|
||||
let terms = search.make_ngrams(tokens, 3, 6);
|
||||
|
||||
assert_eq!(terms["hel"].len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hash_name() {
|
||||
let search = Rustysearch::new("/tmp/rustysearch_hashname");
|
||||
search.setup();
|
||||
|
||||
let hash = search.hash_name("hello", 6);
|
||||
assert_eq!(hash, "5d4140");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_make_segment_name() {
|
||||
let search = Rustysearch::new("/tmp/rustysearch_makesegmentname");
|
||||
search.setup();
|
||||
|
||||
let segment_name = search.make_segment_name("hello");
|
||||
assert_eq!(
|
||||
segment_name,
|
||||
"/tmp/rustysearch_makesegmentname/index/5d4140.index"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_record() {
|
||||
let search = Rustysearch::new("/tmp/rustysearch_parserecord");
|
||||
search.setup();
|
||||
|
||||
let line = "my_term\t{\"frequency\": 100}";
|
||||
let (term, info) = search.parse_record(line);
|
||||
|
||||
assert_eq!(term, "my_term");
|
||||
assert_eq!(info, "{\"frequency\": 100}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_make_tokens() {
|
||||
let search = Rustysearch::new("/tmp/rustysearch");
|
||||
let tokens = search.make_tokens("Hello, world!");
|
||||
assert_eq!(tokens, vec!["hello", "world"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_make_record() {
|
||||
let search = Rustysearch::new("/tmp/rustysearch");
|
||||
let term = "hello world";
|
||||
let term_info = json!({
|
||||
"frequency": 100,
|
||||
"idf": 1.5,
|
||||
});
|
||||
|
||||
let record = search.make_record(term, &term_info);
|
||||
assert_eq!(record, "hello world\t{\"frequency\":100,\"idf\":1.5}\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_term_info() {
|
||||
let mut orig_info = json!({
|
||||
"doc1": ["1", "2"],
|
||||
"doc2": ["3", "4"]
|
||||
});
|
||||
|
||||
let new_info = json!({
|
||||
"doc3": ["1", "2"]
|
||||
});
|
||||
|
||||
let expected_result = json!({
|
||||
"doc1": ["1", "2"],
|
||||
"doc2": ["3", "4"],
|
||||
"doc3": ["1", "2"]
|
||||
});
|
||||
let search = Rustysearch::new("/tmp/rustysearch");
|
||||
let result = search.update_term_info(&mut orig_info, &new_info);
|
||||
|
||||
assert_eq!(result, expected_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_segment() {
|
||||
let search = Rustysearch::new("/tmp/rustysearch_save_segment");
|
||||
search.setup();
|
||||
|
||||
let term = "rust";
|
||||
let term_info = json!({"doc1": ["1", "5"], "doc2": ["2", "6"]});
|
||||
|
||||
// Test saving a new segment
|
||||
let result = search.save_segment(term, &term_info, false);
|
||||
assert_eq!(result, true);
|
||||
|
||||
// Test updating an existing segment
|
||||
let new_term_info = json!({"doc1": ["1", "5", "10"], "doc3": ["3", "7"]});
|
||||
let result = search.save_segment(term, &new_term_info, true);
|
||||
assert_eq!(result, true);
|
||||
|
||||
// Test overwriting an existing segment
|
||||
let result = search.save_segment(term, &term_info, false);
|
||||
assert_eq!(result, true);
|
||||
}
|
||||
|
||||
// Helper function to clean up the stats file
|
||||
fn clean_stats(tmp_path: &str) {
|
||||
let search = Rustysearch::new(tmp_path);
|
||||
search.setup();
|
||||
|
||||
let new_stats = Stats {
|
||||
version: String::from("0.1.0"),
|
||||
total_docs: 0,
|
||||
};
|
||||
search.write_stats(new_stats).unwrap();
|
||||
}
|
||||
}
|
@ -1,53 +0,0 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rustysearch::analyze::tokenizer::Tokenizer;
|
||||
|
||||
#[test]
|
||||
fn test_split_into_words() {
|
||||
let text = "The quick brown fox jumps over the lazy dog.";
|
||||
let stopwords = vec!["the".to_string(), "over".to_string()];
|
||||
let tokenizer = Tokenizer::new(text, stopwords, None);
|
||||
let words = tokenizer.split_into_words();
|
||||
assert_eq!(
|
||||
words,
|
||||
vec![
|
||||
"quick".to_string(),
|
||||
"brown".to_string(),
|
||||
"fox".to_string(),
|
||||
"jumps".to_string(),
|
||||
"lazy".to_string(),
|
||||
"dog".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_into_sentences() {
|
||||
let text = "The quick brown fox jumps over the lazy dog. The end.";
|
||||
let stopwords = vec!["the".to_string(), "over".to_string()];
|
||||
let tokenizer = Tokenizer::new(text, stopwords, None);
|
||||
let sentences = tokenizer.split_into_sentences();
|
||||
assert_eq!(
|
||||
sentences,
|
||||
vec![
|
||||
"quick brown fox jumps lazy dog".to_string(),
|
||||
"end".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_into_paragraphs() {
|
||||
let text = "The quick brown fox jumps over the lazy dog.\n\nThe end.";
|
||||
let stopwords = vec!["the".to_string(), "over".to_string()];
|
||||
let tokenizer = Tokenizer::new(text, stopwords, None);
|
||||
let paragraphs = tokenizer.split_into_paragraphs();
|
||||
assert_eq!(
|
||||
paragraphs,
|
||||
vec![
|
||||
"quick brown fox jumps lazy dog".to_string(),
|
||||
"end".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user