use std::collections::HashSet;
use std::path::{absolute, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, Mutex};

use anyhow::{anyhow, Context, Error};
use clap::{Parser, Subcommand};
use futures::{stream, StreamExt, TryStreamExt};
use rusqlite::{params, OptionalExtension};
use scraper::{Html, Selector};
use snix_castore::blobservice::BlobService;
use snix_castore::directoryservice::DirectoryService;
use snix_castore::B3Digest;
use snix_castore::{blobservice, directoryservice, import::fs::ingest_path};
use tokio::io::{AsyncReadExt, BufReader};
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::Semaphore;
use tokio_stream::wrappers::ReceiverStream;
use url::Url;

/// Something we know how to ingest: either a remote URL or a local path.
#[derive(Clone, Debug)]
enum Ingestable {
    Url(Url),
    Path(PathBuf),
}

#[derive(Debug, Clone)]
enum IngestedWhen {
    Now,
    Before,
}

/// A recorded sample: a URI pinned to a castore blob digest at some epoch.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct Ingested {
    sample_id: u32,
    uri: String,
    blake3: B3Digest,
    epoch: u32,
    when: IngestedWhen,
}

/// Messages passed around while recursively fetching a listing.
#[derive(Clone)]
enum FetchListingMessage {
    Ingested(Url, Ingested),
    Recurse(Url, usize),
}

impl std::fmt::Display for Ingestable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Ingestable::Url(url) => write!(f, "{}", url),
            Ingestable::Path(path_buf) => match path_buf.to_str() {
                Some(s) => write!(f, "{}", s),
                None => {
                    panic!("PathBuf::to_str failed")
                }
            },
        }
    }
}

fn parse_url_or_path(s: &str) -> Result<Ingestable, Error> {
    if s.is_empty() {
        Err(anyhow!("Empty path (url)"))
    } else if s.starts_with("./") || s.starts_with("/") {
        Ok(Ingestable::Path(PathBuf::from(s)))
    } else {
        let url = Url::parse(s)?;
        if url.scheme() == "file" {
            match url.to_file_path() {
                Ok(s) => Ok(Ingestable::Path(s)),
                Err(()) => Err(anyhow!(
                    "parse_url_or_path: couldn't convert Url ({}) to Path",
                    url
                )),
            }
        } else {
            Ok(Ingestable::Url(url))
        }
    }
}

fn data_path() -> PathBuf {
    let xdg_data_dir = std::env::var("XDG_DATA_DIR")
        .and_then(|s| Ok(PathBuf::from(s)))
        .or_else(|_| -> Result<PathBuf, Error> {
            match std::env::home_dir() {
                Some(p) => Ok(p.join(".local/share")),
                None => Err(anyhow!("...")), // FIXME
            }
        });
    match xdg_data_dir {
        Ok(p) => p.join("sidx"),
        Err(_) => PathBuf::from(".sidx"),
    }
}

fn default_castore_path() -> PathBuf {
    data_path().join("castore")
}

fn default_db_path() -> PathBuf {
    data_path().join("sidx.db")
}

#[derive(Subcommand)]
enum Command {
    Ingest {
        #[clap(value_parser = parse_url_or_path, num_args = 1)]
        inputs: Vec<Ingestable>,
    },
    FetchListing {
        #[clap(value_parser, long, default_value_t = 5)]
        max_depth: usize,
        #[clap(value_parser, num_args = 1)]
        inputs: Vec<Url>,
    },
}

#[derive(Parser)]
struct Cli {
    #[clap(short, long, action)]
    refetch: bool,
    #[clap(short, long, value_parser, default_value_t = 4)]
    max_parallel: usize,
    #[clap(short, long, value_parser, default_value_os_t = default_db_path())]
    db_path: PathBuf,
    #[clap(short, long, value_parser, default_value_os_t = default_castore_path())]
    castore_path: PathBuf,
    #[command(subcommand)]
    command: Option<Command>,
}

/// Shared state: the sqlite connection, the castore services, and an HTTP client.
struct SidxContext<BS, DS>
where
    BS: blobservice::BlobService + Clone + Send + 'static,
    DS: directoryservice::DirectoryService + Clone + Send + 'static,
{
    refetch: bool,
    max_parallel: usize,
    http: reqwest::Client,
    con: Arc<Mutex<rusqlite::Connection>>,
    blob_service: BS,
    dir_service: DS,
}

async fn open_context(
    refetch: bool,
    max_parallel: usize,
    db_path: PathBuf,
    castore_path: PathBuf,
) -> SidxContext<Arc<dyn BlobService>, Arc<dyn DirectoryService>> {
    if let Some(p) = db_path.parent() {
        let _ = std::fs::create_dir_all(p);
    }
    let con = rusqlite::Connection::open(&db_path).expect("Failed to construct Database object");
    con.execute_batch(include_str!("q/init.sql"))
        .expect("Failed to execute init.sql");
execute init.sql"); let castore_path = absolute(castore_path).expect("Failed to canonicalize castore_path"); let blob_service = blobservice::from_addr(&std::format!( "objectstore+file://{}", castore_path .join("blob") .to_str() .expect("Path::to_str unexpectedly broken") )) .await .expect("Couldn't initialize .castore/blob"); let dir_service = directoryservice::from_addr(&std::format!( "objectstore+file://{}", castore_path .join("directory") .to_str() .expect("Path::to_str unexpectedly broken") )) .await .expect("Couldn't initialize .castore/directory"); SidxContext::, Arc> { refetch, max_parallel, http: reqwest::Client::new(), con: Arc::new(Mutex::new(con)), blob_service, dir_service, } } impl SidxContext { async fn db_latest_download(&self, uri: &str) -> Result, Error> { let lock = self.con.lock().unwrap(); let mut find_sample = lock .prepare_cached(include_str!("q/latest-download.sql")) .expect("Failed to prepare latest-download.sql"); find_sample .query_row(params![uri], |r| <(u32, String, u32)>::try_from(r)) .optional() .context("db_latest_download.sql") .and_then(|maybe_triple| match maybe_triple { Some((sample_id, blake3, epoch)) => Ok(Some(Ingested { sample_id, uri: uri.to_string(), blake3: B3Digest::from_str(&blake3)?, epoch, when: IngestedWhen::Before, })), None => Ok(None), }) } async fn db_add_sample(&self, uri: &str, blake3: &str) -> Result<(u32, u32), rusqlite::Error> { let lock = self.con.lock().unwrap(); let mut add_sample = lock .prepare_cached(include_str!("q/add-sample.sql")) .expect("Failed to prepare add-sample.sql"); add_sample.query_row(params![uri, blake3], |row| <(u32, u32)>::try_from(row)) } async fn db_add_blob(&self, blake3: &str, n_bytes: u64) -> Result { let lock = self.con.lock().unwrap(); let mut add_blob = lock .prepare_cached(include_str!("q/upsert-blob.sql")) .expect("Failed to prepare upsert-blob.sql"); add_blob.execute(params![blake3, n_bytes,]) } async fn db_add_uri(&self, uri: &str) -> Result { let lock = self.con.lock().unwrap(); let mut add_uri = lock .prepare_cached(include_str!("q/upsert-uri.sql")) .expect("Failed to prepare upsert-uri.sql"); add_uri.execute(params![uri]) } async fn record_ingested_node( &self, uri: &str, blake3: &snix_castore::B3Digest, n_bytes: u64, ) -> Result { let digest64 = format!("{}", blake3); self.db_add_blob(&digest64, n_bytes).await?; self.db_add_uri(&uri).await?; let (sample_id, epoch) = self.db_add_sample(&uri, &digest64).await?; Ok(Ingested { sample_id, uri: uri.to_string(), blake3: blake3.clone(), epoch, when: IngestedWhen::Now, }) } async fn download_no_cache(&self, uri: &Url) -> Result { let uri_s = uri.to_string(); let res = self .http .get(uri.clone()) .send() .await .context(format!("Request::send failed early for {:?}", uri))? .error_for_status()?; let mut r = tokio_util::io::StreamReader::new(res.bytes_stream().map_err(std::io::Error::other)); let mut w = self.blob_service.open_write().await; let n_bytes = match tokio::io::copy(&mut r, &mut w).await { Ok(n) => n, Err(e) => { return Err(anyhow!( "tokio::io::copy failed for uri={} with {}", uri_s, e )); } }; let digest = w.close().await?; self.record_ingested_node(&uri_s, &digest, n_bytes).await } async fn download(&self, uri: &Url) -> Result { if self.refetch { self.download_no_cache(&uri).await } else { match self.db_latest_download(&uri.to_string()).await? 
                Some(ingested) => Ok(ingested),
                None => self.download_no_cache(&uri).await,
            }
        }
    }

    async fn ingest(&self, inputs: &Vec<Ingestable>) -> Vec<Result<Option<Ingested>, Error>> {
        let samples = stream::iter(inputs.iter().map(|uri| {
            let blob_service = &self.blob_service;
            let dir_service = &self.dir_service;
            async move {
                let uri_s = uri.to_string();
                let latest_download = self.db_latest_download(&uri_s).await?;
                if latest_download.is_some() {
                    return Ok(latest_download);
                }
                match uri {
                    Ingestable::Path(path) => {
                        match ingest_path::<_, _, _, &[u8]>(&blob_service, &dir_service, path, None)
                            .await?
                        {
                            snix_castore::Node::Directory { digest, size } => self
                                .record_ingested_node(&uri_s, &digest, size)
                                .await
                                .map(Some),
                            snix_castore::Node::File {
                                digest,
                                size,
                                executable: _,
                            } => self
                                .record_ingested_node(&uri_s, &digest, size)
                                .await
                                .map(Some),
                            snix_castore::Node::Symlink { target: _ } => {
                                Err(anyhow!("TODO: Figure out what to do with symlink roots"))
                            }
                        }
                    }
                    Ingestable::Url(url) => self.download(url).await.map(Some),
                }
            }
        }))
        .buffer_unordered(self.max_parallel)
        .collect::<Vec<Result<Option<Ingested>, _>>>()
        .await;
        samples
    }

    /// Collect the `href` attributes of all `a` elements in an HTML document.
    fn extract_hrefs(content: &str) -> Result<Vec<String>, Error> {
        let sel = Selector::parse("a").map_err(|e| anyhow!(e.to_string()))?;
        let html = Html::parse_document(&content);
        Ok(html
            .select(&sel)
            .flat_map(|elt| elt.value().attr("href"))
            .map(|s| s.to_string())
            .collect::<Vec<_>>())
    }

    async fn fetch_from_listing_impl(
        self: Arc<Self>,
        url: Url,
        max_depth: usize,
        tx: Sender<FetchListingMessage>,
    ) -> Result<(), Error> {
        eprintln!("Downloading {:?}", url.to_string());
        let root = self.download(&url).await?;
        tx.send(FetchListingMessage::Ingested(url.clone(), root.clone()))
            .await
            .context("Stopped accepting tasks before processing an Ingested notification")?;
        if max_depth == 0 {
            return Ok(());
        }
        /* TODO: no need to load blobs into memory unless we know they're text/html */
        match self.blob_service.open_read(&root.blake3).await? {
            Some(mut reader) => {
                let content = {
                    let mut br = BufReader::new(&mut *reader);
                    let mut content = String::new();
                    br.read_to_string(&mut content).await?;
                    content
                };
                let hrefs = Self::extract_hrefs(&content).unwrap_or(vec![]);
                /* max_depth > 0 here */
                for href in hrefs {
                    let next_url = url.join(&href).context("Constructing next_url")?;
                    tx.send(FetchListingMessage::Recurse(
                        next_url.clone(),
                        max_depth - 1,
                    ))
                    .await
                    .context("Stopped accepting tasks before finishing all hrefs")?;
                }
                Ok(())
            }
            None => Err(anyhow!("Couldn't read the ingested blob")),
        }
    }

    /// Crawl starting from `url`, streaming back every ingested sample.
    async fn fetch_from_listing(
        self: Arc<Self>,
        url: Url,
        max_depth: usize,
    ) -> ReceiverStream<Ingested> {
        let mq_size = 10; /* TODO: move task queue to e.g. sqlite */
        let (tx, mut rx) = channel(mq_size);
        let (out_tx, out_rx) = channel(mq_size);
        let semaphore = Arc::new(Semaphore::new(self.max_parallel));
        tokio::spawn({
            async move {
                let mut seen: HashSet<String> = HashSet::new();
                tx.send(FetchListingMessage::Recurse(url, max_depth))
                    .await
                    .expect("fetch_from_listing failed populating the queue");
                while let Some(m) = rx.recv().await {
                    match m {
                        FetchListingMessage::Ingested(_url, ingested) => {
                            out_tx
                                .send(ingested)
                                .await
                                .expect("ReceiverStream failed to accept an Ingestable");
                        }
                        FetchListingMessage::Recurse(url, max_depth) => {
                            if max_depth > 0 && !seen.contains(&url.to_string()) {
                                seen.insert(url.to_string());
                                tokio::spawn({
                                    let s = self.clone();
                                    let url = url.clone();
                                    let tx = tx.clone();
                                    let semaphore = semaphore.clone();
                                    async move {
                                        // Hold a permit while fetching so at most
                                        // `max_parallel` downloads run concurrently.
                                        let _permit = semaphore
                                            .acquire()
                                            .await
                                            .expect("Semaphore closed unexpectedly");
                                        s.fetch_from_listing_impl(url, max_depth, tx).await
                                    }
                                });
                            }
                        }
                    }
                }
            }
        });
        ReceiverStream::new(out_rx)
    }
}

#[tokio::main]
async fn main() {
    let args = Cli::parse();
    let _cwd = std::env::current_dir().expect("Couldn't get CWD");
    let _host_name = std::env::var("HOSTNAME").map_or(None, Some);
    let ctx = Arc::new(
        open_context(
            args.refetch,
            args.max_parallel,
            args.db_path,
            args.castore_path,
        )
        .await,
    );
    match args.command {
        Some(Command::Ingest { inputs }) => {
            let samples = ctx.ingest(&inputs).await;
            for s in samples {
                match s {
                    Err(e) => {
                        eprintln!("Failed to fetch: {}", e);
                    }
                    Ok(None) => {}
                    Ok(Some(ingested)) => {
                        eprintln!("{:?}", ingested)
                    }
                }
            }
        }
        Some(Command::FetchListing { max_depth, inputs }) => {
            let ingested: Vec<Ingested> = stream::iter(inputs)
                .then(async |i| {
                    let i = i.clone();
                    ctx.clone().fetch_from_listing(i, max_depth).await
                })
                .flatten_unordered(args.max_parallel)
                .collect()
                .await;
            for i in ingested {
                eprintln!("{:?}", i);
            }
        }
        None => {}
    }
}
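
// Added sketch (not part of the original file): smoke tests for the two pure
// helpers above, `parse_url_or_path` and `extract_hrefs`. No I/O or database
// access happens here; the `Arc<dyn BlobService>` / `Arc<dyn DirectoryService>`
// parameters are only used to name the concrete `SidxContext` type and are
// assumed to satisfy the service bounds, exactly as in `open_context`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn classifies_paths_and_urls() {
        // Relative and absolute paths are treated as local ingestables.
        assert!(matches!(
            parse_url_or_path("./some/local/dir"),
            Ok(Ingestable::Path(_))
        ));
        // Anything else that parses as a non-file URL is fetched over HTTP.
        assert!(matches!(
            parse_url_or_path("https://example.com/listing/"),
            Ok(Ingestable::Url(_))
        ));
        // Empty input is rejected.
        assert!(parse_url_or_path("").is_err());
    }

    #[test]
    fn extracts_hrefs_from_anchor_elements() {
        let html = r#"<html><body><a href="a.txt">a</a> <a href="sub/">sub</a></body></html>"#;
        let hrefs =
            SidxContext::<Arc<dyn BlobService>, Arc<dyn DirectoryService>>::extract_hrefs(html)
                .expect("the `a` selector should always parse");
        // hrefs are returned in document order.
        assert_eq!(hrefs, vec!["a.txt".to_string(), "sub/".to_string()]);
    }
}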