Skip to content

Commit

Permalink
separated plugin engine and lua api, added providers config to result…
Browse files Browse the repository at this point in the history
…s ctx for tera, added some more motds
  • Loading branch information
frc4533-lincoln committed Nov 1, 2024
1 parent 11aedc5 commit 2643203
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 191 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ extern crate serde;
extern crate searched_parser;

pub mod config;
pub mod lua_api;
pub mod lua_support;

use std::collections::HashMap;

Expand Down
158 changes: 11 additions & 147 deletions src/lua_api.rs → src/lua_support/api.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
use core::error;
use std::{
collections::{HashMap, VecDeque},
error::Error,
fs::{read_dir, File},
io::Read,
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};

use fend_core::Context;
use mlua::prelude::*;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::Client;
use scraper::{node::Element, Html, Selector};
use tokio::{
sync::{oneshot, watch, Mutex},
task::{spawn_local, LocalSet},
sync::Mutex,
};
use url::Url;

use crate::{config::ProvidersConfig, Query};
use crate::Query;

impl LuaUserData for Query {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
Expand All @@ -29,7 +22,7 @@ impl LuaUserData for Query {
}

/// Lua wrapper for [url::Url]
struct UrlWrapper(Url);
pub struct UrlWrapper(Url);
impl UrlWrapper {
fn parse(_: &Lua, url: String) -> LuaResult<Self> {
Url::parse(&url).map(|x| UrlWrapper(x)).into_lua_err()
Expand Down Expand Up @@ -102,7 +95,7 @@ impl LuaUserData for UrlWrapper {
}

/// Lua wrapper for [reqwest::Client]
struct ClientWrapper(Client);
pub struct ClientWrapper(pub Client);
impl ClientWrapper {
async fn get(
_: Lua,
Expand Down Expand Up @@ -158,7 +151,7 @@ impl LuaUserData for ClientWrapper {
}

/// Lua wrapper for [scraper::Html]
struct Scraper(Arc<Mutex<Html>>);
pub struct Scraper(Arc<Mutex<Html>>);
impl LuaUserData for Scraper {
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_function("new", |_, raw_html: String| {
Expand All @@ -185,7 +178,7 @@ impl LuaUserData for Scraper {

/// Lua wrapper for [scraper::Element]
#[derive(Clone)]
struct ElementWrapper(String, Element);
pub struct ElementWrapper(String, Element);
unsafe impl Send for ElementWrapper {}
impl LuaUserData for ElementWrapper {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
Expand All @@ -198,157 +191,28 @@ impl LuaUserData for ElementWrapper {
}
}

fn add_engine(lua: &Lua, (name, callback): (String, LuaFunction)) -> LuaResult<()> {
pub fn add_engine(lua: &Lua, (name, callback): (String, LuaFunction)) -> LuaResult<()> {
lua.globals()
.get::<LuaTable>("__searched_engines__")?
.set(name, callback.clone())?;

Ok(())
}
fn stringify_params(_: &Lua, params: LuaTable) -> LuaResult<String> {
pub fn stringify_params(_: &Lua, params: LuaTable) -> LuaResult<String> {
Ok(params
.pairs::<String, String>()
.filter_map(|ent| ent.ok().map(|(k, v)| [k, v].join("&")))
.collect::<Vec<_>>()
.join("&"))
}
fn parse_json(lua: &Lua, raw: String) -> LuaResult<LuaValue> {
pub fn parse_json(lua: &Lua, raw: String) -> LuaResult<LuaValue> {
let json: serde_json::Value = serde_json::from_str(&raw).into_lua_err()?;
lua.to_value(&json)
}
fn fend_eval(_: &Lua, input: String) -> LuaResult<String> {
pub fn fend_eval(_: &Lua, input: String) -> LuaResult<String> {
Ok(fend_core::evaluate(&input, &mut Context::new())
.unwrap()
.get_main_result()
.to_string())
}

/// A single-threaded plugin engine
#[derive(Clone)]
pub struct PluginEngine {
lua: Lua,
client: Client,
#[cfg(not(feature = "hot_reload"))]
providers: ProvidersConfig,
}
impl PluginEngine {
/// Initialize a new engine for running plugins
pub async fn new(client: Client) -> Result<Self, Box<dyn Error>> {
#[cfg(not(feature = "hot_reload"))]
let providers = ProvidersConfig::load("plugins/providers.toml");

debug!("initializing plugin engine...");

let lua = Lua::new();

// Add Lua global variables we need
lua.globals()
.set("__searched_engines__", lua.create_table()?)?;

// Add Lua interfaces
lua.globals()
.set("Url", lua.create_proxy::<UrlWrapper>()?)?;
lua.globals().set("Query", lua.create_proxy::<Query>()?)?;
lua.globals()
.set("Client", lua.create_proxy::<ClientWrapper>()?)?;
lua.globals()
.set("Scraper", lua.create_proxy::<Scraper>()?)?;
lua.globals()
.set("Element", lua.create_proxy::<ElementWrapper>()?)?;

// Add standalone Lua functions
lua.globals()
.set("add_engine", lua.create_function(add_engine)?)?;
lua.globals()
.set("stringify_params", lua.create_function(stringify_params)?)?;
lua.globals()
.set("parse_json", lua.create_function(parse_json)?)?;
lua.globals()
.set("fend_eval", lua.create_function(fend_eval)?)?;

debug!("Initialized plugin engine! loading engines...");

// Load engines
Self::load_engines(&lua).await;

debug!("loaded engines!");

Ok(Self {
lua,
client,
#[cfg(not(feature = "hot_reload"))]
providers,
})
}

pub async fn load_engines(lua: &Lua) {
for path in read_dir("plugins/engines").unwrap() {
if let Ok(path) = path {
// Do war crime level code
let name = path.path().file_stem().expect("bad file path").to_str().expect("filename should be utf-8").to_string();

debug!("loading {name}...");
let load_st = Instant::now();

// Read the source code into buf
let mut buf = String::new();
let mut f = File::open(path.path()).unwrap();
f.read_to_string(&mut buf).unwrap();

lua.load(&buf).exec_async().await.unwrap();

debug!("loaded {name} in {:?}!", load_st.elapsed());
}
}
}

/// Process the given query
pub async fn search(&self, query: Query) -> Vec<crate::Result> {
#[cfg(feature = "hot_reload")]
Self::load_engines(&self.lua).await;

#[cfg(feature = "hot_reload")]
let providers = ProvidersConfig::load("plugins/providers.toml");
#[cfg(not(feature = "hot_reload"))]
let providers = &self.providers;

if let Some(provider) = providers.0.get(&query.provider) {
let engine = provider.engine.clone().unwrap_or_else(|| query.provider.clone());
let target = format!("searched::engine::{engine}");

// Get engine implementation
let eng_impl = self
.lua
.globals()
.get::<LuaTable>("__searched_engines__")
.unwrap()
.get::<LuaFunction>(engine)
.unwrap();

// Run engine for query
let results = eng_impl
.call_async::<Vec<LuaTable>>((
ClientWrapper(self.client.clone()),
query.clone(),
self.lua
.to_value(&provider.clone().extra.unwrap_or_default()),
))
.await;

match results {
Ok(results) => {
return results
.into_iter()
.map(|r| self.lua.from_value(LuaValue::Table(r)).unwrap())
.collect();
}
Err(err) => {
error!(target: &target, "failed to get results from provider {}: {}", query.provider, err);
}
}
}

Vec::new()
}
}

142 changes: 142 additions & 0 deletions src/lua_support/engine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
use core::error::Error;
use std::{
fs::{read_dir, File},
io::Read,
time::Instant,
};

use mlua::prelude::*;
use reqwest::Client;

use super::api::*;
use crate::{config::ProvidersConfig, Query};

/// A single-threaded plugin engine
#[derive(Clone)]
pub struct PluginEngine {
lua: Lua,
client: Client,
#[cfg(not(feature = "hot_reload"))]
providers: ProvidersConfig,
}
impl PluginEngine {
/// Initialize a new engine for running plugins
pub async fn new(client: Client) -> Result<Self, Box<dyn Error>> {
#[cfg(not(feature = "hot_reload"))]
let providers = ProvidersConfig::load("plugins/providers.toml");

debug!("initializing plugin engine...");

let lua = Lua::new();

// Add Lua global variables we need
lua.globals()
.set("__searched_engines__", lua.create_table()?)?;

// Add Lua interfaces
lua.globals()
.set("Url", lua.create_proxy::<UrlWrapper>()?)?;
lua.globals().set("Query", lua.create_proxy::<Query>()?)?;
lua.globals()
.set("Client", lua.create_proxy::<ClientWrapper>()?)?;
lua.globals()
.set("Scraper", lua.create_proxy::<Scraper>()?)?;
lua.globals()
.set("Element", lua.create_proxy::<ElementWrapper>()?)?;

// Add standalone Lua functions
lua.globals()
.set("add_engine", lua.create_function(add_engine)?)?;
lua.globals()
.set("stringify_params", lua.create_function(stringify_params)?)?;
lua.globals()
.set("parse_json", lua.create_function(parse_json)?)?;
lua.globals()
.set("fend_eval", lua.create_function(fend_eval)?)?;

debug!("Initialized plugin engine! loading engines...");

// Load engines
Self::load_engines(&lua).await;

debug!("loaded engines!");

Ok(Self {
lua,
client,
#[cfg(not(feature = "hot_reload"))]
providers,
})
}

pub async fn load_engines(lua: &Lua) {
for path in read_dir("plugins/engines").unwrap() {
if let Ok(path) = path {
// Do war crime level code
let name = path.path().file_stem().expect("bad file path").to_str().expect("filename should be utf-8").to_string();

debug!("loading {name}...");
let load_st = Instant::now();

// Read the source code into buf
let mut buf = String::new();
let mut f = File::open(path.path()).unwrap();
f.read_to_string(&mut buf).unwrap();

lua.load(&buf).exec_async().await.unwrap();

debug!("loaded {name} in {:?}!", load_st.elapsed());
}
}
}

/// Process the given query
pub async fn search(&self, query: Query) -> Vec<crate::Result> {
#[cfg(feature = "hot_reload")]
Self::load_engines(&self.lua).await;

#[cfg(feature = "hot_reload")]
let providers = ProvidersConfig::load("plugins/providers.toml");
#[cfg(not(feature = "hot_reload"))]
let providers = &self.providers;

if let Some(provider) = providers.0.get(&query.provider) {
let engine = provider.engine.clone().unwrap_or_else(|| query.provider.clone());
let target = format!("searched::engine::{engine}");

// Get engine implementation
let eng_impl = self
.lua
.globals()
.get::<LuaTable>("__searched_engines__")
.unwrap()
.get::<LuaFunction>(engine)
.unwrap();

// Run engine for query
let results = eng_impl
.call_async::<Vec<LuaTable>>((
ClientWrapper(self.client.clone()),
query.clone(),
self.lua
.to_value(&provider.clone().extra.unwrap_or_default()),
))
.await;

match results {
Ok(results) => {
return results
.into_iter()
.map(|r| self.lua.from_value(LuaValue::Table(r)).unwrap())
.collect();
}
Err(err) => {
error!(target: &target, "failed to get results from provider {}: {}", query.provider, err);
}
}
}

Vec::new()
}
}

4 changes: 4 additions & 0 deletions src/lua_support/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod api;
mod engine;

pub use engine::PluginEngine;
5 changes: 1 addition & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ use axum::{
};
use log::LevelFilter;
//use reqwest::Client;
use searched::{
config::Config,
lua_api::PluginEngine,
};
use searched::lua_support::PluginEngine;
//use sled::Db;
use tokio::net::TcpListener;

Expand Down
Loading

0 comments on commit 2643203

Please sign in to comment.