use std::{collections::HashMap, ops::Deref as _}; use tokio_stream::StreamExt as _; use tracing::{Instrument, debug, debug_span, error, trace, warn}; use zbus::{ Connection, ObjectServer, fdo::Result, names::OwnedUniqueName, object_server::SignalEmitter, proxy::ProxyImpl, zvariant::{self, OwnedObjectPath}, }; pub const RESPONSE_SUCCESS: u32 = 0; // pub const RESPONSE_CANCELLED: u32 = 1; pub const RESPONSE_OTHER: u32 = 2; /// A handler for the org.freedesktop.portal.Request interface, proxying to another /// instance of the same interface. struct Request { host: RequestProxy<'static>, } #[zbus::interface( name = "org.freedesktop.portal.Request", proxy(default_service = "org.freedesktop.portal.Desktop") )] impl Request { #[zbus(signal)] async fn response( signal_emitter: &SignalEmitter<'_>, response: u32, results: HashMap<&str, zvariant::Value<'_>>, ) -> zbus::Result<()>; async fn close(&self) -> Result<()> { self.host.close().await } } pub trait ResultTransformer { fn apply<'a>( self, response: u32, results: HashMap<&'a str, zvariant::Value<'a>>, ) -> impl std::future::Future>)>> + std::marker::Send; } impl ResultTransformer for () { async fn apply<'a>( self, response: u32, results: HashMap<&'a str, zvariant::Value<'a>>, ) -> Result<(u32, HashMap<&'a str, zvariant::Value<'a>>)> { Ok((response, results)) } } pub struct ReqHandler { token: String, sender: Option, conn: Connection, server: ObjectServer, host_conn: Connection, transformer: T, } impl ReqHandler<()> { pub fn prepare<'a>( host: &impl ProxyImpl<'a>, hdr: zbus::message::Header<'_>, server: &ObjectServer, conn: &Connection, options: &HashMap<&str, zvariant::Value<'_>>, ) -> Self { ReqHandler { token: get_token(options), sender: hdr.sender().map(|s| s.to_owned().into()), conn: conn.to_owned(), server: server.to_owned(), host_conn: host.inner().connection().to_owned(), transformer: (), } } } impl ReqHandler { pub fn with_transform(self, transformer: T1) -> ReqHandler { ReqHandler { transformer, token: self.token, sender: self.sender, conn: self.conn, server: self.server, host_conn: self.host_conn, } } } impl ReqHandler { pub async fn perform( self, call: impl AsyncFnOnce() -> Result, ) -> Result { let sender = self.sender.ok_or_else(|| zbus::Error::MissingField)?; let sender_str = sender.trim_start_matches(':').replace('.', "_"); let token = self.token; let path = zvariant::ObjectPath::try_from(format!( "/org/freedesktop/portal/desktop/request/{sender_str}/{token}" )) .map_err(zbus::Error::from)?; let host_path = call().await?; let imp = Request { host: RequestProxy::builder(&self.host_conn) .path(host_path)? .build() .await?, }; let stream = imp.host.receive_response().await?; if !self.server.at(&path, imp).await? { return Err(zbus::fdo::Error::Failed( "Duplicate request object path".to_owned(), )); } let path_1: OwnedObjectPath = path.clone().into(); let sender = sender.to_owned().into(); tokio::spawn( forward_response(stream, self.conn.clone(), path_1, sender, self.transformer) .instrument(debug_span!("response proxy", ?path)), ); Ok(path.into()) } } fn get_token(options: &HashMap<&str, zvariant::Value<'_>>) -> String { match options.get("handle_token") { Some(zvariant::Value::Str(str)) => { trace!("extracted token from handle_token option"); return String::from(str.deref()); } Some(value) => warn!(?value, "handle_token option provided but not a string"), None => trace!("handle_token not provided"), }; use rand::distr::{Alphanumeric, SampleString}; Alphanumeric.sample_string(&mut rand::rng(), 16) } async fn forward_response( mut stream: ResponseStream, conn: Connection, path: zvariant::OwnedObjectPath, sender: zbus::names::OwnedUniqueName, transform: impl ResultTransformer, ) -> Result<()> { let signal_emitter = SignalEmitter::new(&conn, path)? .set_destination(zbus::names::BusName::Unique(sender.into())) .into_owned(); let Some(resp) = stream.next().await else { debug!("response stream gone"); return Ok(()); }; debug!(?resp, "got resp"); let (response, results) = match resp.0.deserialize() { Ok((response, results)) => match transform.apply(response, results).await { Ok(res) => res, Err(err) => { error!(%err, "transform error"); (RESPONSE_OTHER, HashMap::new()) } }, Err(err) => { error!(%err, "signal body type mismatch"); (RESPONSE_OTHER, HashMap::new()) } }; if let Err(err) = Request::response(&signal_emitter, response, results).await { error!(%err, "signal forwarding failed"); } Ok(()) }