// sidx/src/main.rs

use std::collections::HashSet;
use std::path::{absolute, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use anyhow::Context;
use anyhow::{anyhow, Error};
use clap::Parser;
use clap::Subcommand;
use futures::{stream, StreamExt, TryStreamExt};
use rusqlite::{params, OptionalExtension};
use scraper::{Html, Selector};
use snix_castore::blobservice::BlobService;
use snix_castore::directoryservice::DirectoryService;
use snix_castore::B3Digest;
use snix_castore::{blobservice, directoryservice, import::fs::ingest_path};
use std::sync::Mutex;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::Semaphore;
use tokio_stream::wrappers::ReceiverStream;
use url::Url;
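
/// An input that sidx can ingest: either a remote URL or a local filesystem path.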
#[derive(Clone, Debug)]
enum Ingestable {
Url(Url),
Path(PathBuf),
}
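
/// Whether a sample was taken during this run (`Now`) or found already recorded
/// in the database (`Before`).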
#[derive(Debug, Clone)]
enum SampledWhen {
Now,
Before,
}
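
/// A castore blob digest together with its size in bytes.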
#[derive(Debug, Clone)]
struct SizedBlob {
hash: B3Digest,
n_bytes: u64,
}
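
/// One recorded observation of a URI: the blob it resolved to (if any), the HTTP
/// status (if it was fetched over HTTP), and when the sample was taken.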
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct Sampled {
sample_id: u32,
uri: String,
blob: Option<SizedBlob>,
http_status: Option<u16>,
epoch: u32,
when: SampledWhen,
}
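
/// Messages exchanged by the listing crawler: a finished `Sampled` result, or a
/// `Recurse` request carrying the next URL, the remaining depth, and the sender
/// to report back through.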
#[derive(Clone)]
enum FetchListingMessage {
Sampled(Url, Sampled),
Recurse(Url, usize, Sender<FetchListingMessage>),
}
impl std::fmt::Display for Ingestable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ingestable::Url(url) => write!(f, "{}", url),
            Ingestable::Path(path_buf) => write!(f, "{}", path_buf.display()),
}
}
}
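
/// Interpret a CLI argument as a local path (leading `./` or `/`, or a `file://`
/// URL) or as a remote URL.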
fn parse_url_or_path(s: &str) -> Result<Ingestable, Error> {
if s.is_empty() {
Err(anyhow!("Empty path (url)"))
} else if s.starts_with("./") || s.starts_with("/") {
Ok(Ingestable::Path(PathBuf::from(s)))
} else {
let url = Url::parse(s)?;
if url.scheme() == "file" {
match url.to_file_path() {
Ok(s) => Ok(Ingestable::Path(s)),
Err(()) => Err(anyhow!(
"parse_url_or_path: couldn't convert Url ({}) to Path",
url
)),
}
} else {
Ok(Ingestable::Url(url))
}
}
}
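
/// Per-user data directory: `$XDG_DATA_HOME/sidx`, falling back to
/// `~/.local/share/sidx`, or `./.sidx` when no home directory can be determined.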
fn data_path() -> PathBuf {
    let xdg_data_dir = std::env::var("XDG_DATA_HOME")
        .map(PathBuf::from)
        .or_else(|_| -> Result<PathBuf, Error> {
            match std::env::home_dir() {
                Some(p) => Ok(p.join(".local/share")),
                None => Err(anyhow!("Couldn't determine a home directory")),
            }
        });
    match xdg_data_dir {
        Ok(p) => p.join("sidx"),
        Err(_) => PathBuf::from(".sidx"),
    }
}
fn default_castore_path() -> PathBuf {
data_path().join("castore")
}
fn default_db_path() -> PathBuf {
data_path().join("sidx.db")
}
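
/// Subcommands: ingest URLs or paths into the castore, crawl an HTML listing,
/// or just parse and print URLs.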
#[derive(Subcommand)]
enum Command {
Ingest {
#[clap(value_parser = parse_url_or_path, num_args = 1)]
inputs: Vec<Ingestable>,
},
FetchListing {
#[clap(value_parser, long, default_value_t = 5)]
max_depth: usize,
#[clap(value_parser, long, default_value_t = 1024 * 1024)]
html_max_bytes: u64,
#[clap(value_parser, num_args = 1)]
inputs: Vec<Url>,
},
ParseUrl {
#[clap(value_parser, num_args = 1)]
url: Vec<Url>,
},
}
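
/// Top-level command-line options.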
#[derive(Parser)]
struct Cli {
#[clap(short, long, action)]
refetch: bool,
#[clap(short, long, value_parser, default_value_t = 2)]
max_parallel: usize,
#[clap(short, long, value_parser, default_value_os_t = default_db_path())]
db_path: PathBuf,
#[clap(short, long, value_parser, default_value_os_t = default_castore_path())]
castore_path: PathBuf,
#[command(subcommand)]
command: Option<Command>,
}
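
/// Shared state for all operations: an HTTP client with a semaphore bounding
/// concurrent requests, the SQLite connection, and the castore blob and
/// directory services.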
struct SidxContext<BS, DS>
where
BS: blobservice::BlobService + Clone + Send + 'static,
DS: directoryservice::DirectoryService + Clone + Send + 'static,
{
refetch: bool,
max_parallel: usize,
http: reqwest::Client,
http_semaphore: Arc<Semaphore>,
con: Arc<Mutex<rusqlite::Connection>>,
blob_service: BS,
dir_service: DS,
}
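
/// Open (creating them if necessary) the SQLite database and the on-disk castore,
/// and assemble a ready-to-use `SidxContext`.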
async fn open_context(
refetch: bool,
max_parallel: usize,
db_path: PathBuf,
castore_path: PathBuf,
) -> SidxContext<Arc<dyn BlobService>, Arc<dyn DirectoryService>> {
if let Some(p) = db_path.parent() {
let _ = std::fs::create_dir_all(p);
}
let con = rusqlite::Connection::open(&db_path).expect("Failed to construct Database object");
con.execute_batch(include_str!("q/sidx-init.sql"))
.expect("Failed to execute sidx-init.sql");
let castore_path = absolute(castore_path).expect("Failed to canonicalize castore_path");
let blob_service = blobservice::from_addr(&std::format!(
"objectstore+file://{}",
castore_path
.join("blob")
.to_str()
.expect("Path::to_str unexpectedly broken")
))
.await
.expect("Couldn't initialize .castore/blob");
let dir_service = directoryservice::from_addr(&std::format!(
"objectstore+file://{}",
castore_path
.join("directory")
.to_str()
.expect("Path::to_str unexpectedly broken")
))
.await
.expect("Couldn't initialize .castore/directory");
SidxContext::<Arc<dyn BlobService>, Arc<dyn DirectoryService>> {
refetch,
max_parallel,
http: reqwest::Client::new(),
http_semaphore: Arc::new(Semaphore::new(max_parallel)),
con: Arc::new(Mutex::new(con)),
blob_service,
dir_service,
}
}
impl<BS: BlobService + Clone, DS: DirectoryService + Clone> SidxContext<BS, DS> {
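    /// Return the most recent sample recorded for `uri`, if any.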
async fn db_latest_download(&self, uri: &str) -> Result<Option<Sampled>, Error> {
let lock = self.con.lock().unwrap();
let mut find_sample = lock
.prepare_cached(include_str!("q/latest-download.sql"))
.expect("Failed to prepare latest-download.sql");
find_sample
.query_row(params![uri], |r| {
<(u32, String, u64, Option<u16>, u32)>::try_from(r)
})
.optional()
.context("db_latest_download.sql")
.and_then(|maybe_tuple| match maybe_tuple {
Some((sample_id, hash, n_bytes, http_code, epoch)) => Ok(Some(Sampled {
sample_id,
uri: uri.to_string(),
blob: Some(SizedBlob {
hash: B3Digest::from_str(&hash)?,
n_bytes,
}),
http_status: http_code,
epoch,
when: SampledWhen::Before,
})),
None => Ok(None),
})
}
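
    /// Insert a sample row for `uri` with an optional blob hash and HTTP status;
    /// returns the `(sample_id, epoch)` produced by add-sample.sql.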
async fn db_add_sample(
&self,
uri: &str,
hash: &Option<String>,
http_code: Option<u16>,
) -> Result<(u32, u32), Error> {
let lock = self.con.lock().expect("Couldn't lock mutex");
let mut add_sample = lock
.prepare_cached(include_str!("q/add-sample.sql"))
.context("Failed to prepare add-sample.sql")?;
Ok(add_sample.query_row(params![uri, hash, http_code], |row| {
<(u32, u32)>::try_from(row)
})?)
}
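
    /// Upsert a blob row (digest and byte count).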
async fn db_add_blob(&self, hash: &str, n_bytes: u64) -> Result<usize, Error> {
let lock = self.con.lock().expect("db_add_blob: couldn't lock mutex?");
let mut add_blob = lock
.prepare_cached(include_str!("q/upsert-blob.sql"))
.context("Failed to prepare upsert-blob.sql")?;
Ok(add_blob.execute(params![hash, n_bytes,])?)
}
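
    /// Upsert a URI row.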
async fn db_add_uri(&self, uri: &str) -> Result<usize, Error> {
let lock = self.con.lock().unwrap();
let mut add_uri = lock
.prepare_cached(include_str!("q/upsert-uri.sql"))
.context("Failed to prepare upsert-uri.sql")?;
Ok(add_uri.execute(params![uri])?)
}
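
    /// Persist the outcome of an ingestion: upsert the blob (if any) and the URI,
    /// then add a sample row tying them together.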
async fn record_ingested_node(
&self,
uri: &str,
blob: &Option<SizedBlob>,
http_code: Option<u16>,
) -> Result<Sampled, Error> {
        let digest64 = if let Some(SizedBlob { hash, n_bytes }) = blob {
            let digest64 = format!("{}", hash);
            self.db_add_blob(&digest64, *n_bytes).await?;
            Some(digest64)
        } else {
            None
        };
        self.db_add_uri(uri).await?;
        let (sample_id, epoch) = self.db_add_sample(uri, &digest64, http_code).await?;
Ok(Sampled {
sample_id,
uri: uri.to_string(),
blob: blob.clone(),
http_status: http_code,
epoch,
when: SampledWhen::Now,
})
}
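
    /// Fetch `uri` over HTTP regardless of any cached sample, streaming the body
    /// into the blob service; concurrency is bounded by the HTTP semaphore.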
async fn download_no_cache(&self, uri: &Url) -> Result<Sampled, Error> {
let _permit = self.http_semaphore.acquire().await.unwrap();
eprintln!("Downloading {:?}", uri.to_string());
let uri_s = uri.to_string();
let res = self
.http
.get(uri.clone())
.send()
.await
.context(format!("Request::send failed early for {:?}", uri))?;
let status = res.status();
let status_code = status.as_u16();
if status.is_success() {
let mut r = tokio_util::io::StreamReader::new(
res.bytes_stream().map_err(std::io::Error::other),
);
let mut w = self.blob_service.open_write().await;
let n_bytes = match tokio::io::copy(&mut r, &mut w).await {
Ok(n) => n,
Err(e) => {
return Err(anyhow!(
"tokio::io::copy failed for uri={} with {}",
uri_s,
e
));
}
};
let digest = w.close().await?;
self.record_ingested_node(
&uri_s,
&Some(SizedBlob {
hash: digest,
n_bytes,
}),
Some(status_code),
)
.await
} else {
self.record_ingested_node(&uri_s, &None, Some(status_code))
.await
}
}
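
    /// Fetch `uri`, reusing the latest recorded download unless `--refetch` was given.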
async fn download(&self, uri: &Url) -> Result<Sampled, Error> {
if self.refetch {
self.download_no_cache(&uri).await
} else {
match self.db_latest_download(&uri.to_string()).await? {
Some(ingested) => Ok(ingested),
None => self.download_no_cache(&uri).await,
}
}
}
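
    /// Ingest every input, local paths via the castore filesystem importer and URLs
    /// via HTTP, running at most `max_parallel` ingestions concurrently.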
async fn ingest(&self, inputs: &Vec<Ingestable>) -> Vec<Result<Option<Sampled>, Error>> {
let samples = stream::iter(inputs.iter().map(|uri| {
let blob_service = &self.blob_service;
let dir_service = &self.dir_service;
async move {
let uri_s = uri.to_string();
let latest_download = self.db_latest_download(&uri_s).await?;
if latest_download.is_some() {
return Ok(latest_download);
}
match uri {
Ingestable::Path(path) => {
match ingest_path::<_, _, _, &[u8]>(&blob_service, &dir_service, path, None)
.await?
{
snix_castore::Node::Directory { digest, size } => self
.record_ingested_node(
&uri_s,
&Some(SizedBlob {
hash: digest,
n_bytes: size,
}),
None,
)
.await
.map(Some),
snix_castore::Node::File {
digest,
size,
executable: _,
} => self
.record_ingested_node(
&uri_s,
&Some(SizedBlob {
hash: digest,
n_bytes: size,
}),
None,
)
.await
.map(Some),
snix_castore::Node::Symlink { target: _ } => {
Err(anyhow!("TODO: Figure out what to do with symlink roots"))
}
}
}
Ingestable::Url(url) => self.download(url).await.map(Some),
}
}
}))
.buffer_unordered(self.max_parallel)
.collect::<Vec<Result<Option<Sampled>, _>>>()
.await;
samples
}
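
    /// Collect the `href` attribute of every `<a>` element in an HTML document.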
fn extract_hrefs(content: &str) -> Result<Vec<String>, Error> {
let sel = Selector::parse("a").map_err(|e| anyhow!(e.to_string()))?;
let html = Html::parse_document(&content);
Ok(html
.select(&sel)
.flat_map(|elt| elt.value().attr("href"))
.map(|s| s.to_string())
.collect::<Vec<_>>())
}
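
    /// Download one listing page, report it as `Sampled`, and, while `max_depth`
    /// allows and the page is no larger than `html_max_bytes`, record its hrefs and
    /// enqueue `Recurse` messages for each of them.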
async fn fetch_from_listing_impl(
self: Arc<Self>,
url: Url,
max_depth: usize,
html_max_bytes: u64,
tx: Sender<FetchListingMessage>,
) -> Result<(), Error> {
let maybe_root = self.download(&url).await;
if let Err(ref e) = maybe_root {
eprintln!("Couldn't download {}: {:?}", url, e);
};
let root = maybe_root?;
tx.send(FetchListingMessage::Sampled(url.clone(), root.clone()))
.await
.context("Stopped accepting tasks before processing an Ingested notification")?;
        if max_depth == 0 {
            return Ok(());
        }
match root.blob {
None => Err(anyhow!(
"Couldn't download {}. Status code: {:?}",
url,
root.http_status
)),
Some(SizedBlob { hash, n_bytes }) => {
if n_bytes > html_max_bytes {
return Ok(());
}
match self.blob_service.open_read(&hash).await? {
Some(mut reader) => {
let content = {
let mut br = BufReader::new(&mut *reader);
let mut content = String::new();
br.read_to_string(&mut content).await?;
content
};
                        let hrefs = Self::extract_hrefs(&content).unwrap_or_default();
/* max_depth > 0 here */
for href in hrefs.clone() {
let next_url = url.join(&href).context("Constructing next_url")?;
tx.send(FetchListingMessage::Recurse(
next_url.clone(),
max_depth - 1,
tx.clone(),
))
.await
.context("Stopped accepting tasks before finishing all hrefs")?;
}
{
let lock = self.con.lock().expect("Couldn't acquire Mutex?");
for href in hrefs {
let mut stmt =
lock.prepare_cached(include_str!("q/add-str.sql"))?;
                                stmt.execute(params![href])?;
let next_url = url.join(&href).context("Constructing next_url")?;
let mut stmt =
lock.prepare_cached(include_str!("q/add-uri-ref.sql"))?;
let digest64 = hash.to_string();
stmt.execute(params![digest64, next_url.to_string(), href])?;
}
};
Ok(())
}
None => Err(anyhow!("Couldn't read the ingested blob")),
}
}
}
}
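
    /// Crawl an HTML listing starting at `url`, skipping URLs already seen, and
    /// stream every resulting `Sampled` back to the caller.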
async fn fetch_from_listing(
self: Arc<Self>,
url: Url,
max_depth: usize,
html_max_bytes: u64,
) -> ReceiverStream<Sampled> {
let mq_size = 10;
/* TODO: move task queue to e.g. sqlite */
let (tx, mut rx) = channel(mq_size);
let (out_tx, out_rx) = channel(mq_size);
tokio::spawn({
async move {
let mut seen: HashSet<String> = HashSet::new();
{
let tx_moved = tx;
tx_moved
.send(FetchListingMessage::Recurse(
url,
max_depth,
tx_moved.clone(),
))
.await
.expect("fetch_from_listing failed populating the queue");
};
while let Some(m) = rx.recv().await {
match m {
FetchListingMessage::Sampled(_url, ingested) => {
out_tx
.send(ingested)
.await
.expect("ReceiverStream failed to accept an Ingestable");
}
FetchListingMessage::Recurse(url, max_depth, tx) => {
if max_depth > 0 && !seen.contains(&url.to_string()) {
seen.insert(url.to_string());
tokio::spawn({
let s = self.clone();
let url = url.clone();
async move {
s.fetch_from_listing_impl(
url,
max_depth,
html_max_bytes,
tx,
)
.await
}
});
}
}
}
}
}
});
ReceiverStream::new(out_rx)
}
}
#[tokio::main]
async fn main() {
let args = Cli::parse();
let _cwd = std::env::current_dir().expect("Couldn't get CWD");
    let _host_name = std::env::var("HOSTNAME").ok();
let ctx = Arc::new(
open_context(
args.refetch,
args.max_parallel,
args.db_path,
args.castore_path,
)
.await,
);
match args.command {
Some(Command::Ingest { inputs }) => {
let samples = ctx.ingest(&inputs).await;
for s in samples {
match s {
Err(e) => {
eprintln!("Failed to fetch: {}", e);
}
Ok(None) => {}
Ok(Some(ingested)) => {
eprintln!("{:?}", ingested)
}
}
}
}
Some(Command::FetchListing {
max_depth,
html_max_bytes,
inputs,
}) => {
let ingested: Vec<Sampled> = stream::iter(inputs)
.then(async |i| {
let i = i.clone();
ctx.clone()
.fetch_from_listing(i, max_depth, html_max_bytes)
.await
})
.flatten_unordered(args.max_parallel)
.collect()
.await;
for i in ingested {
eprintln!("{:?}", i);
}
}
Some(Command::ParseUrl { url: urls }) => {
for url in urls {
println!("{:?}", url);
}
}
None => {}
}
}