From e85ccd62ae27d8efe5cb99db9628f6a0dce92868 Mon Sep 17 00:00:00 2001 From: Nathan Perry Date: Fri, 24 Apr 2026 07:50:03 -0400 Subject: [PATCH 1/5] elixir: restore start_tracing Signed-off-by: Nathan Perry Change-Id: Ibdfa3c10dd50379a43b52a07c1a713cd6a6a6964 --- Cargo.lock | 1 + ts_elixir/native/ts_elixir/Cargo.toml | 1 + ts_elixir/native/ts_elixir/src/lib.rs | 18 +++++++++++++++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 79ad49dc..6f34ad7d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4288,6 +4288,7 @@ dependencies = [ "tailscale", "tokio", "tracing", + "tracing-subscriber", ] [[package]] diff --git a/ts_elixir/native/ts_elixir/Cargo.toml b/ts_elixir/native/ts_elixir/Cargo.toml index e5c660d6..94d7850c 100644 --- a/ts_elixir/native/ts_elixir/Cargo.toml +++ b/ts_elixir/native/ts_elixir/Cargo.toml @@ -17,6 +17,7 @@ tailscale = { workspace = true } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } [lib] crate-type = ["cdylib"] diff --git a/ts_elixir/native/ts_elixir/src/lib.rs b/ts_elixir/native/ts_elixir/src/lib.rs index 5836de49..e1ef7e5d 100644 --- a/ts_elixir/native/ts_elixir/src/lib.rs +++ b/ts_elixir/native/ts_elixir/src/lib.rs @@ -4,10 +4,11 @@ use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, str::FromStr, - sync::{Arc, LazyLock}, + sync::{Arc, LazyLock, Once}, }; use rustler::{Encoder, NifResult, ResourceArc, Term}; +use tracing::level_filters::LevelFilter; mod config; mod tcp; @@ -103,6 +104,21 @@ where Ok(ResourceArc::new(t)) } +#[rustler::nif] +fn start_tracing() { + static TRACING_ONCE: Once = Once::new(); + + TRACING_ONCE.call_once(|| { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + }); +} + #[rustler::nif(schedule = "DirtyIo")] fn connect<'env>( env: rustler::Env<'env>, From 4b54688249158023c46095dd8e57793ed205d7d3 Mon Sep 17 00:00:00 2001 From: Nathan Perry Date: Fri, 24 Apr 2026 07:50:03 -0400 Subject: [PATCH 2/5] elixir: fix deworkspace cargo.toml script It erroneously expected all deps to have `workspace = true`, now checks this correctly. Signed-off-by: Nathan Perry Change-Id: Iafe01153ca52b9f618c80666de5e6f186a6a6964 --- ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py b/ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py index e5ed63e6..72492424 100644 --- a/ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py +++ b/ts_elixir/native/ts_elixir/deworkspace_cargo_toml.py @@ -60,7 +60,7 @@ def main(): for dep in list(cargotoml[name].keys()): value = cargotoml[name][dep] - if type(value) == dict and value['workspace'] is True: + if type(value) == dict and value.get('workspace') is True: if args.repo_sha and (dep.startswith('tailscale') or dep.startswith('ts_')): value['git'] = f'https://github.com/tailscale/tailscale-rs' value['rev'] = args.repo_sha From e0f163846a397935dfdcf827f4f103bdde9b86f0 Mon Sep 17 00:00:00 2001 From: Nathan Perry Date: Fri, 24 Apr 2026 07:50:03 -0400 Subject: [PATCH 3/5] elixir: factor out helpers and type conversions Signed-off-by: Nathan Perry Change-Id: Iafe01153ca52b9f618c80666de5e6f186a6a6964 --- ts_elixir/native/ts_elixir/src/config.rs | 4 +- ts_elixir/native/ts_elixir/src/erl_ip.rs | 93 +++++++++ ts_elixir/native/ts_elixir/src/helpers.rs | 31 +++ ts_elixir/native/ts_elixir/src/ip_or_self.rs | 51 +++++ ts_elixir/native/ts_elixir/src/lib.rs | 205 +++---------------- ts_elixir/native/ts_elixir/src/node_info.rs | 43 ++++ ts_elixir/native/ts_elixir/src/tcp.rs | 80 ++++---- ts_elixir/native/ts_elixir/src/udp.rs | 54 +++-- 8 files changed, 315 insertions(+), 246 deletions(-) create mode 100644 ts_elixir/native/ts_elixir/src/erl_ip.rs create mode 100644 ts_elixir/native/ts_elixir/src/helpers.rs create mode 100644 ts_elixir/native/ts_elixir/src/ip_or_self.rs create mode 100644 ts_elixir/native/ts_elixir/src/node_info.rs diff --git a/ts_elixir/native/ts_elixir/src/config.rs b/ts_elixir/native/ts_elixir/src/config.rs index 888c9b3a..ce3468f1 100644 --- a/ts_elixir/native/ts_elixir/src/config.rs +++ b/ts_elixir/native/ts_elixir/src/config.rs @@ -29,14 +29,14 @@ pub fn config_from_erl( config.key_state = value .decode::()? .try_into() - .map_err(|_| rustler::Error::Atom("badkeys"))?; + .map_err(|_| rustler::Error::BadArg)?; } if let Some(value) = erl_config.get(&atoms::control_url()) { config.control_server_url = value.decode::<&str>()?.parse().map_err(|e| { tracing::error!(error = %e, "parsing control server url"); - rustler::Error::Atom("bad_url") + rustler::Error::BadArg })?; } diff --git a/ts_elixir/native/ts_elixir/src/erl_ip.rs b/ts_elixir/native/ts_elixir/src/erl_ip.rs new file mode 100644 index 00000000..97d6027b --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/erl_ip.rs @@ -0,0 +1,93 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + str::FromStr, +}; + +use rustler::{Encoder, NifResult, Term}; + +/// Erlang-formatted IP. +/// +/// Supports decoding from either a string or `:inet` (tuple of octets or segments) format, +/// always encodes into the `:inet` format. +#[derive(Copy, Clone, Debug)] +pub struct ErlIp(pub IpAddr); + +impl From for ErlIp { + fn from(value: Ipv4Addr) -> Self { + Self(value.into()) + } +} + +impl From for ErlIp { + fn from(value: Ipv6Addr) -> Self { + Self(value.into()) + } +} + +impl From for ErlIp { + fn from(value: IpAddr) -> Self { + Self(value) + } +} + +impl From for IpAddr { + fn from(value: ErlIp) -> Self { + value.0 + } +} + +impl<'a> rustler::Decoder<'a> for ErlIp { + fn decode(ip: Term<'a>) -> NifResult { + if let Ok(tuple) = rustler::types::tuple::get_tuple(ip) { + if tuple.len() == 4 { + let mut octets = [0u8; 4]; + + for (i, elem) in tuple.into_iter().take(4).enumerate() { + octets[i] = elem.decode()?; + } + + return Ok(Self(Ipv4Addr::from_octets(octets).into())); + } + + if tuple.len() == 8 { + let mut segments = [0u16; 8]; + + for (i, elem) in tuple.into_iter().take(8).enumerate() { + segments[i] = elem.decode()?; + } + + return Ok(Self(Ipv6Addr::from_segments(segments).into())); + } + } + + if let Ok(s) = ip.decode::<&str>() { + let ip = IpAddr::from_str(s).map_err(|e| { + tracing::error!(error = %e, "parsing ip addr"); + + rustler::Error::BadArg + })?; + + return Ok(Self(ip)); + } + + Err(rustler::Error::BadArg) + } +} + +impl Encoder for ErlIp { + fn encode<'a>(&self, env: rustler::Env<'a>) -> Term<'a> { + match self.0 { + IpAddr::V4(ip) => { + let octets = ip.octets(); + (octets[0], octets[1], octets[2], octets[3]).encode(env) + } + IpAddr::V6(ip) => { + // rustler doesn't provide `impl Encoder` for 8-length tuples + let segments = ip.segments().map(|segment| segment.encode(env)); + + let tuple = rustler::types::tuple::make_tuple(env, &segments); + tuple.encode(env) + } + } + } +} diff --git a/ts_elixir/native/ts_elixir/src/helpers.rs b/ts_elixir/native/ts_elixir/src/helpers.rs new file mode 100644 index 00000000..b9a324d6 --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/helpers.rs @@ -0,0 +1,31 @@ +use std::{error::Error, fmt::Display, net::SocketAddr}; + +use rustler::{Encoder, NifResult, ResourceArc, Term}; + +use crate::{atoms, erl_ip::ErlIp}; + +pub type Result = std::result::Result>; + +/// Wrap the given [`rustler::Resource`] in a [`ResourceArc`] inside a [`NifResult`]. +pub fn ok_arc(t: T) -> NifResult> +where + T: rustler::Resource, +{ + Ok(ResourceArc::new(t)) +} + +/// Convert the argument into a [`rustler::Error`] by making it into a string. +pub fn term_err(e: impl Display) -> rustler::Error { + rustler::Error::Term(Box::new(e.to_string())) +} + +pub fn sockaddr_to_erl(addr: SocketAddr) -> (ErlIp, u16) { + (ErlIp(addr.ip()), addr.port()) +} + +pub fn erl_result(env: rustler::Env, r: Result) -> NifResult { + match r { + Ok(t) => Ok((atoms::ok(), t).encode(env)), + Err(e) => Err(term_err(e)), + } +} diff --git a/ts_elixir/native/ts_elixir/src/ip_or_self.rs b/ts_elixir/native/ts_elixir/src/ip_or_self.rs new file mode 100644 index 00000000..f4482d5d --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/ip_or_self.rs @@ -0,0 +1,51 @@ +use std::net::IpAddr; + +use rustler::{Error, NifResult, Term}; + +use crate::{atoms, erl_ip::ErlIp}; + +/// A literal IP address, the atom `:ip4`, or the atom `:ip6`. +/// +/// The latter two mean this node's IPv4 or IPv6 address, respectively. +pub enum IpOrSelf { + Ip(ErlIp), + SelfV4, + SelfV6, +} + +impl<'a> rustler::Decoder<'a> for IpOrSelf { + fn decode(ip: Term<'a>) -> NifResult { + if let Ok(ip) = ip.decode::() { + return Ok(Self::Ip(ip)); + } + + let atom = ip.decode::()?; + if atom == atoms::ip4() { + return Ok(Self::SelfV4); + } + + if atom == atoms::ip6() { + return Ok(Self::SelfV6); + } + + Err(Error::BadArg) + } +} + +impl IpOrSelf { + pub async fn resolve(&self, dev: &tailscale::Device) -> NifResult { + match self { + IpOrSelf::Ip(ip) => Ok(ip.0), + IpOrSelf::SelfV4 => dev + .ipv4_addr() + .await + .map(Into::into) + .map_err(|e| Error::Term(Box::new(e.to_string()))), + IpOrSelf::SelfV6 => dev + .ipv6_addr() + .await + .map(Into::into) + .map_err(|e| Error::Term(Box::new(e.to_string()))), + } + } +} diff --git a/ts_elixir/native/ts_elixir/src/lib.rs b/ts_elixir/native/ts_elixir/src/lib.rs index e1ef7e5d..d0a52219 100644 --- a/ts_elixir/native/ts_elixir/src/lib.rs +++ b/ts_elixir/native/ts_elixir/src/lib.rs @@ -2,8 +2,6 @@ use std::{ collections::HashMap, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - str::FromStr, sync::{Arc, LazyLock, Once}, }; @@ -11,13 +9,22 @@ use rustler::{Encoder, NifResult, ResourceArc, Term}; use tracing::level_filters::LevelFilter; mod config; +mod erl_ip; +mod helpers; +mod ip_or_self; +mod node_info; mod tcp; mod udp; +use config::Keystate; +use erl_ip::ErlIp; +use helpers::{Result, ok_arc, sockaddr_to_erl, term_err}; +use ip_or_self::IpOrSelf; +use node_info::NodeInfo; use tcp::{TcpListener, TcpStream}; use udp::UdpSocket; -use crate::config::Keystate; +use crate::helpers::erl_result; mod atoms { rustler::atoms! { @@ -33,49 +40,6 @@ struct Device { inner: Arc, } -#[derive(rustler::NifStruct)] -#[module = "Tailscale.NodeInfo"] -struct NodeInfo<'a> { - id: i64, - stable_id: String, - hostname: String, - tailnet: Option, - tags: Vec, - tailnet_addresses: Vec>, - derp_region: Option, - node_key: String, - disco_key: Option, - machine_key: Option, - underlay_addresses: Vec>, -} - -impl<'a> NodeInfo<'a> { - fn from_node(env: rustler::Env<'a>, value: tailscale::NodeInfo) -> Self { - Self { - id: value.id, - stable_id: value.stable_id.0, - hostname: value.hostname, - tailnet: value.tailnet, - tags: value.tags, - tailnet_addresses: vec![ - ip_to_erl(env, value.tailnet_address.ipv4.addr()), - ip_to_erl(env, value.tailnet_address.ipv6.addr()), - ], - derp_region: value.derp_region.map(|x| x.0.get()), - node_key: value.node_key.to_string(), - disco_key: value.disco_key.as_ref().map(ToString::to_string), - machine_key: value.machine_key.as_ref().map(ToString::to_string), - underlay_addresses: value - .underlay_addresses - .into_iter() - .map(|x| (ip_to_erl(env, x.ip()), x.port()).encode(env)) - .collect(), - } - } -} - -type Result = core::result::Result>; - #[rustler::resource_impl] impl rustler::Resource for Device {} @@ -90,20 +54,6 @@ static TOKIO_RUNTIME: LazyLock = LazyLock::new(|| { rt }); -fn erl_result(env: rustler::Env, r: Result) -> Term { - match r { - Ok(t) => (atoms::ok(), t).encode(env), - Err(e) => (atoms::error(), e.to_string()).encode(env), - } -} - -fn ok_arc(t: T) -> Result> -where - T: rustler::Resource, -{ - Ok(ResourceArc::new(t)) -} - #[rustler::nif] fn start_tracing() { static TRACING_ONCE: Once = Once::new(); @@ -123,25 +73,24 @@ fn start_tracing() { fn connect<'env>( env: rustler::Env<'env>, opts: HashMap>, -) -> NifResult<(rustler::Atom, Term<'env>)> { +) -> NifResult> { let (config, auth_key) = config::config_from_erl(&opts)?; let dev = TOKIO_RUNTIME.block_on(async move { - let dev = tailscale::Device::new(&config, auth_key).await?; + let dev = tailscale::Device::new(&config, auth_key) + .await + .map_err(term_err)?; ok_arc(Device { inner: Arc::new(dev), }) }); - match dev { - Ok(dev) => Ok((atoms::ok(), dev.encode(env))), - Err(e) => Err(rustler::Error::Term(Box::new(e.to_string()))), - } + dev.map(|d| d.encode(env)) } #[rustler::nif(schedule = "DirtyIo")] -fn load_key_file(env: rustler::Env, path: &str) -> impl Encoder { +fn load_key_file(env: rustler::Env, path: &str) -> NifResult { let result = TOKIO_RUNTIME .block_on(tailscale::config::load_key_file(path, Default::default())) .map(Keystate::from) @@ -151,21 +100,21 @@ fn load_key_file(env: rustler::Env, path: &str) -> impl Encoder { } #[rustler::nif(schedule = "DirtyIo")] -fn ipv4_addr(env: rustler::Env, dev: ResourceArc) -> impl Encoder { +fn ipv4_addr(env: rustler::Env, dev: ResourceArc) -> NifResult { let dev = dev.inner.clone(); let addr = TOKIO_RUNTIME.block_on(dev.ipv4_addr()); - erl_result(env, addr.map(|ip| ip_to_erl(env, ip)).map_err(Into::into)) + erl_result(env, addr.map(|ip| ErlIp(ip.into())).map_err(Into::into)) } #[rustler::nif(schedule = "DirtyIo")] -fn ipv6_addr(env: rustler::Env<'_>, dev: ResourceArc) -> impl Encoder { +fn ipv6_addr(dev: ResourceArc) -> NifResult { let dev = dev.inner.clone(); - match TOKIO_RUNTIME.block_on(dev.ipv6_addr()) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(ip) => (atoms::ok(), ip_to_erl(env, ip)).encode(env), - } + TOKIO_RUNTIME + .block_on(dev.ipv6_addr()) + .map(ErlIp::from) + .map_err(term_err) } #[rustler::nif(schedule = "DirtyIo")] @@ -176,7 +125,7 @@ fn peer_by_name(env: rustler::Env<'_>, dev: ResourceArc, name: &str) -> match TOKIO_RUNTIME.block_on(async move { dev.peer_by_name(&name).await }) { Err(e) => (atoms::error(), e.to_string()).encode(env), Ok(None) => (atoms::ok(), Option::<()>::None).encode(env), - Ok(Some(peer)) => (atoms::ok(), NodeInfo::from_node(env, peer)).encode(env), + Ok(Some(peer)) => (atoms::ok(), NodeInfo::from(peer)).encode(env), } } @@ -186,127 +135,35 @@ fn self_node(env: rustler::Env<'_>, dev: ResourceArc) -> impl Encoder { match TOKIO_RUNTIME.block_on(async move { dev.self_node().await }) { Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(peer) => (atoms::ok(), NodeInfo::from_node(env, peer)).encode(env), + Ok(peer) => (atoms::ok(), NodeInfo::from(peer)).encode(env), } } #[rustler::nif(schedule = "DirtyIo")] -fn peer_by_tailnet_ip(env: rustler::Env<'_>, dev: ResourceArc, ip: Term) -> impl Encoder { +fn peer_by_tailnet_ip(env: rustler::Env<'_>, dev: ResourceArc, ip: ErlIp) -> impl Encoder { let dev = dev.inner.clone(); - let Some(ip) = ip_from_erl(ip) else { - return env.error_tuple("invalid ip"); - }; - match TOKIO_RUNTIME.block_on(async move { dev.peer_by_tailnet_ip(ip).await }) { + match TOKIO_RUNTIME.block_on(async move { dev.peer_by_tailnet_ip(ip.0).await }) { Err(e) => (atoms::error(), e.to_string()).encode(env), Ok(None) => (atoms::ok(), Option::<()>::None).encode(env), - Ok(Some(peer)) => (atoms::ok(), NodeInfo::from_node(env, peer)).encode(env), + Ok(Some(peer)) => (atoms::ok(), NodeInfo::from(peer)).encode(env), } } #[rustler::nif(schedule = "DirtyIo")] -fn peers_with_route(env: rustler::Env<'_>, dev: ResourceArc, ip: Term) -> impl Encoder { +fn peers_with_route(env: rustler::Env<'_>, dev: ResourceArc, ip: ErlIp) -> impl Encoder { let dev = dev.inner.clone(); - let Some(ip) = ip_from_erl(ip) else { - return env.error_tuple("invalid ip"); - }; - match TOKIO_RUNTIME.block_on(async move { dev.peers_with_route(ip).await }) { + match TOKIO_RUNTIME.block_on(async move { dev.peers_with_route(ip.0).await }) { Err(e) => (atoms::error(), e.to_string()).encode(env), Ok(peers) => ( atoms::ok(), - peers - .into_iter() - .map(|x| NodeInfo::from_node(env, x)) - .collect::>(), + peers.into_iter().map(NodeInfo::from).collect::>(), ) .encode(env), } } -fn ip_to_erl(env: rustler::Env, ip: impl Into) -> Term { - match ip.into() { - IpAddr::V4(ip) => { - let octets = ip.octets(); - (octets[0], octets[1], octets[2], octets[3]).encode(env) - } - IpAddr::V6(ip) => { - // rustler doesn't provide `impl Encoder` for 8-length tuples - let segments = ip.segments().map(|segment| segment.encode(env)); - - let tuple = rustler::types::tuple::make_tuple(env, &segments); - tuple.encode(env) - } - } -} - -enum IpOrSelf { - Ip(IpAddr), - SelfV4, - SelfV6, -} - -impl IpOrSelf { - pub fn new(ip: Term<'_>) -> Option { - if let Some(ip) = ip_from_erl(ip) { - return Some(Self::Ip(ip)); - } - - let atom = ip.decode::().ok()?; - if atom == atoms::ip4() { - return Some(Self::SelfV4); - } - - if atom == atoms::ip6() { - return Some(Self::SelfV6); - } - - None - } - - pub async fn resolve(&self, dev: &tailscale::Device) -> Result { - match self { - IpOrSelf::Ip(ip) => Ok(*ip), - IpOrSelf::SelfV4 => dev.ipv4_addr().await.map(Into::into).map_err(Into::into), - IpOrSelf::SelfV6 => dev.ipv6_addr().await.map(Into::into).map_err(Into::into), - } - } -} - -fn ip_from_erl(ip: Term) -> Option { - if let Ok(tuple) = rustler::types::tuple::get_tuple(ip) { - if tuple.len() == 4 { - let mut octets = [0u8; 4]; - - for (i, elem) in tuple.into_iter().take(4).enumerate() { - octets[i] = elem.decode().ok()?; - } - - return Some(Ipv4Addr::from_octets(octets).into()); - } - - if tuple.len() == 8 { - let mut segments = [0u16; 8]; - - for (i, elem) in tuple.into_iter().take(8).enumerate() { - segments[i] = elem.decode().ok()?; - } - - return Some(Ipv6Addr::from_segments(segments).into()); - } - } - - if let Ok(s) = ip.decode::<&str>() { - return IpAddr::from_str(s).ok(); - } - - None -} - -fn sockaddr_to_erl(env: rustler::Env, addr: SocketAddr) -> impl Encoder { - (ip_to_erl(env, addr.ip()), addr.port()) -} - fn load(env: rustler::Env, _term: Term) -> bool { let ret = env.register::().is_ok() && env.register::().is_ok() diff --git a/ts_elixir/native/ts_elixir/src/node_info.rs b/ts_elixir/native/ts_elixir/src/node_info.rs new file mode 100644 index 00000000..638a68d8 --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/node_info.rs @@ -0,0 +1,43 @@ +use crate::{erl_ip::ErlIp, helpers::sockaddr_to_erl}; + +/// Info about a Tailscale peer. +#[derive(rustler::NifStruct)] +#[module = "Tailscale.NodeInfo"] +pub struct NodeInfo { + id: i64, + stable_id: String, + hostname: String, + tailnet: Option, + tags: Vec, + tailnet_addresses: Vec, + derp_region: Option, + node_key: String, + disco_key: Option, + machine_key: Option, + underlay_addresses: Vec<(ErlIp, u16)>, +} + +impl From for NodeInfo { + fn from(value: tailscale::NodeInfo) -> Self { + Self { + id: value.id, + stable_id: value.stable_id.0, + hostname: value.hostname, + tailnet: value.tailnet, + tags: value.tags, + tailnet_addresses: vec![ + ErlIp::from(value.tailnet_address.ipv4.addr()), + ErlIp::from(value.tailnet_address.ipv6.addr()), + ], + derp_region: value.derp_region.map(|x| x.0.get()), + node_key: value.node_key.to_string(), + disco_key: value.disco_key.as_ref().map(ToString::to_string), + machine_key: value.machine_key.as_ref().map(ToString::to_string), + underlay_addresses: value + .underlay_addresses + .into_iter() + .map(sockaddr_to_erl) + .collect(), + } + } +} diff --git a/ts_elixir/native/ts_elixir/src/tcp.rs b/ts_elixir/native/ts_elixir/src/tcp.rs index a539d34a..4ef8da9d 100644 --- a/ts_elixir/native/ts_elixir/src/tcp.rs +++ b/ts_elixir/native/ts_elixir/src/tcp.rs @@ -1,8 +1,10 @@ use std::sync::Arc; -use rustler::{Encoder, ResourceArc}; +use rustler::{Encoder, NifResult, ResourceArc}; -use crate::{IpOrSelf, Result, TOKIO_RUNTIME, atoms, erl_result, ip_from_erl, ok_arc}; +use crate::{ + IpOrSelf, Result, TOKIO_RUNTIME, atoms, erl_ip::ErlIp, erl_result, helpers::term_err, ok_arc, +}; pub(crate) struct TcpListener { inner: Arc, @@ -20,66 +22,66 @@ impl rustler::Resource for TcpStream {} #[rustler::nif(schedule = "DirtyIo")] fn tcp_listen( - env: rustler::Env, dev: ResourceArc, - addr: rustler::Term, + addr: IpOrSelf, port: u16, -) -> impl Encoder { +) -> NifResult { let dev = dev.inner.clone(); - let ip = IpOrSelf::new(addr); - let sock = TOKIO_RUNTIME.block_on(async move { - let addr = ip.ok_or("invalid ip addr")?.resolve(&dev).await?; - let sock = dev.tcp_listen((addr, port).into()).await?; + TOKIO_RUNTIME.block_on(async move { + let addr = addr.resolve(&dev).await?; + let sock = dev + .tcp_listen((addr, port).into()) + .await + .map_err(term_err)?; ok_arc(TcpListener { inner: Arc::new(sock), }) - }); - - erl_result(env, sock) + }) } #[rustler::nif] -fn tcp_listen_local_addr(env: rustler::Env, listener: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(env, listener.inner.local_addr()) +fn tcp_listen_local_addr(listener: ResourceArc) -> impl Encoder { + crate::sockaddr_to_erl(listener.inner.local_addr()) } #[rustler::nif(schedule = "DirtyIo")] fn tcp_connect( - env: rustler::Env<'_>, + env: rustler::Env, dev: ResourceArc, - addr: rustler::Term, + addr: ErlIp, port: u16, -) -> impl Encoder { - let addr = ip_from_erl(addr); +) -> NifResult { let dev = dev.inner.clone(); - let sock = TOKIO_RUNTIME.block_on(async move { - let addr = addr.ok_or("invalid ip addr")?; - let sock = dev.tcp_connect((addr, port).into()).await?; + TOKIO_RUNTIME + .block_on(async move { + let sock = dev + .tcp_connect((addr, port).into()) + .await + .map_err(term_err)?; - ok_arc(TcpStream { - inner: Arc::new(sock), + ok_arc(TcpStream { + inner: Arc::new(sock), + }) }) - }); - - erl_result(env, sock) + .map(|sock| sock.encode(env)) } #[rustler::nif(schedule = "DirtyIo")] -fn tcp_accept(env: rustler::Env<'_>, sock: ResourceArc) -> impl Encoder { +fn tcp_accept(env: rustler::Env<'_>, sock: ResourceArc) -> NifResult { let inner = sock.inner.clone(); - let sock = TOKIO_RUNTIME.block_on(async move { - let stream = inner.accept().await?; + TOKIO_RUNTIME + .block_on(async move { + let stream = inner.accept().await.map_err(term_err)?; - ok_arc(TcpStream { - inner: Arc::new(stream), + ok_arc(TcpStream { + inner: Arc::new(stream), + }) }) - }); - - erl_result(env, sock) + .map(|sock| sock.encode(env)) } #[rustler::nif(schedule = "DirtyIo")] @@ -93,7 +95,7 @@ fn tcp_send(env: rustler::Env, sock: ResourceArc, msg: Vec) -> ru } #[rustler::nif(schedule = "DirtyIo")] -fn tcp_recv(env: rustler::Env, sock: ResourceArc) -> impl Encoder { +fn tcp_recv(env: rustler::Env, sock: ResourceArc) -> NifResult { let inner = sock.inner.clone(); let buf = TOKIO_RUNTIME.block_on(async move { @@ -105,11 +107,11 @@ fn tcp_recv(env: rustler::Env, sock: ResourceArc) -> impl Encoder { } #[rustler::nif] -fn tcp_local_addr(env: rustler::Env, sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(env, sock.inner.local_addr()) +fn tcp_local_addr(sock: ResourceArc) -> impl Encoder { + crate::sockaddr_to_erl(sock.inner.local_addr()) } #[rustler::nif] -fn tcp_remote_addr(env: rustler::Env, sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(env, sock.inner.remote_addr()) +fn tcp_remote_addr(sock: ResourceArc) -> impl Encoder { + crate::sockaddr_to_erl(sock.inner.remote_addr()) } diff --git a/ts_elixir/native/ts_elixir/src/udp.rs b/ts_elixir/native/ts_elixir/src/udp.rs index 21945328..73a5e869 100644 --- a/ts_elixir/native/ts_elixir/src/udp.rs +++ b/ts_elixir/native/ts_elixir/src/udp.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use rustler::{Binary, Encoder, ResourceArc, Term}; +use rustler::{Binary, Encoder, NifResult, ResourceArc, Term}; use crate::{ - Device, IpOrSelf, Result, TOKIO_RUNTIME, atoms, erl_result, ip_from_erl, ip_to_erl, ok_arc, + Device, IpOrSelf, Result, TOKIO_RUNTIME, atoms, erl_ip::ErlIp, helpers::term_err, ok_arc, }; pub struct UdpSocket { @@ -14,38 +14,39 @@ pub struct UdpSocket { impl rustler::Resource for UdpSocket {} #[rustler::nif(schedule = "DirtyIo")] -fn udp_bind(env: rustler::Env, dev: ResourceArc, ip: Term, port: u16) -> impl Encoder { +fn udp_bind( + env: rustler::Env, + dev: ResourceArc, + ip: IpOrSelf, + port: u16, +) -> NifResult { let dev = dev.inner.clone(); - let ip = IpOrSelf::new(ip); - let sock = TOKIO_RUNTIME.block_on(async move { - let addr = ip.ok_or("invalid ip addr")?.resolve(&dev).await?; - let sock = dev.udp_bind((addr, port).into()).await?; + TOKIO_RUNTIME + .block_on(async move { + let addr = ip.resolve(&dev).await?; + let sock = dev.udp_bind((addr, port).into()).await.map_err(term_err)?; - ok_arc(UdpSocket { - inner: Arc::new(sock), + ok_arc(UdpSocket { + inner: Arc::new(sock), + }) }) - }); - - erl_result(env, sock) + .map(|sock| sock.encode(env)) } #[rustler::nif(schedule = "DirtyIo")] fn udp_send<'env>( env: rustler::Env<'env>, sock: ResourceArc, - ip: Term, + ip: ErlIp, port: u16, msg: Binary, ) -> Term<'env> { - let addr = ip_from_erl(ip); let msg = msg.to_vec(); let sock = sock.inner.clone(); match TOKIO_RUNTIME.block_on(async move { - let addr = addr.ok_or("invalid ip addr")?; - - sock.send_to((addr, port).into(), &msg).await?; + sock.send_to((ip.0, port).into(), &msg).await?; Result::<_>::Ok(()) }) { @@ -55,22 +56,13 @@ fn udp_send<'env>( } #[rustler::nif(schedule = "DirtyIo")] -fn udp_recv(env: rustler::Env, sock: ResourceArc) -> Term { - let (who, msg) = match sock.inner.recv_from_bytes_blocking() { - Ok((who, msg)) => (who, msg), - Err(e) => return erl_result(env, Result::<()>::Err(e.into())), - }; +fn udp_recv(env: rustler::Env, sock: ResourceArc) -> NifResult { + let (who, msg) = sock.inner.recv_from_bytes_blocking().map_err(term_err)?; - ( - atoms::ok(), - ip_to_erl(env, who.ip()), - who.port(), - msg.to_vec(), - ) - .encode(env) + Ok((atoms::ok(), ErlIp(who.ip()), who.port(), msg.to_vec()).encode(env)) } #[rustler::nif] -fn udp_local_addr(env: rustler::Env, sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(env, sock.inner.local_addr()) +fn udp_local_addr(sock: ResourceArc) -> impl Encoder { + crate::sockaddr_to_erl(sock.inner.local_addr()) } From 8dd3376889cc37296ab15019ab9749d081a14031 Mon Sep 17 00:00:00 2001 From: Nathan Perry Date: Fri, 24 Apr 2026 07:50:03 -0400 Subject: [PATCH 4/5] elixir: refactor to avoid use of dirtyio Potentially-blocking Rust-side calls now use message passing to respond to the caller. Elixir now has `Tailscale.Util.await`, which spawns a task that listens for the relevant response. Signed-off-by: Nathan Perry Change-Id: Iafe01153ca52b9f618c80666de5e6f186a6a6964 --- Cargo.lock | 8 + ts_elixir/lib/tailscale.ex | 27 +-- ts_elixir/lib/tailscale/native.ex | 89 ++++++++-- ts_elixir/lib/tailscale/tcp.ex | 6 +- ts_elixir/lib/tailscale/tcp/listener.ex | 4 +- ts_elixir/lib/tailscale/tcp/stream.ex | 8 +- ts_elixir/lib/tailscale/udp.ex | 8 +- ts_elixir/lib/tailscale/util.ex | 59 +++++++ ts_elixir/native/ts_elixir/Cargo.toml | 5 +- ts_elixir/native/ts_elixir/src/async_reply.rs | 160 ++++++++++++++++++ ts_elixir/native/ts_elixir/src/helpers.rs | 15 +- ts_elixir/native/ts_elixir/src/lib.rs | 131 +++++++------- ts_elixir/native/ts_elixir/src/tcp.rs | 93 +++++----- ts_elixir/native/ts_elixir/src/udp.rs | 69 ++++---- 14 files changed, 496 insertions(+), 186 deletions(-) create mode 100644 ts_elixir/lib/tailscale/util.ex create mode 100644 ts_elixir/native/ts_elixir/src/async_reply.rs diff --git a/Cargo.lock b/Cargo.lock index 6f34ad7d..720ec010 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3602,6 +3602,12 @@ dependencies = [ "url", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "target-lexicon" version = "0.13.5" @@ -4284,8 +4290,10 @@ dependencies = [ name = "ts_elixir" version = "0.3.0" dependencies = [ + "futures-util", "rustler", "tailscale", + "tap", "tokio", "tracing", "tracing-subscriber", diff --git a/ts_elixir/lib/tailscale.ex b/ts_elixir/lib/tailscale.ex index 3582e87a..71dcb47f 100644 --- a/ts_elixir/lib/tailscale.ex +++ b/ts_elixir/lib/tailscale.ex @@ -1,4 +1,6 @@ defmodule Tailscale do + require Tailscale.Util + @moduledoc """ Elixir bindings for the Tailscale Rust client. @@ -66,8 +68,8 @@ defmodule Tailscale do See `t:options/0` for details on available options. """ - def connect(key_file_path, options) when is_binary(key_file_path) do - case Tailscale.Native.load_key_file(key_file_path) do + def connect(key_file_path, options) when is_binary(key_file_path) and is_list(options) do + case Tailscale.Util.await(Tailscale.Native.load_key_file(key_file_path)) do {:ok, keys} -> Keyword.put(options, :keys, keys) |> connect() @@ -86,8 +88,10 @@ defmodule Tailscale do """ def connect(options \\ []) - def connect(options) when is_list(options), - do: :proplists.to_map(options) |> Tailscale.Native.connect() + def connect(options) when is_list(options) do + options = :proplists.to_map(options) + Tailscale.Util.await(Tailscale.Native.connect(options)) + end def connect(key_file_path) when is_binary(key_file_path), do: connect(key_file_path, []) @@ -97,7 +101,7 @@ defmodule Tailscale do Blocks until the address is available. """ - def ipv4_addr(dev), do: Tailscale.Native.ipv4_addr(dev) + def ipv4_addr(dev), do: Tailscale.Util.await(Tailscale.Native.ipv4_addr(dev)) @spec ipv6_addr(t()) :: {:ok, :inet.ip6_address()} | {:error, any()} @doc """ @@ -108,13 +112,13 @@ defmodule Tailscale do Note that this address is in `t::inet.ip6_address/0` format (16-bit segments), which may be difficult to read. See `:inet.ntoa/1` to format to a string. """ - def ipv6_addr(dev), do: Tailscale.Native.ipv6_addr(dev) + def ipv6_addr(dev), do: Tailscale.Util.await(Tailscale.Native.ipv6_addr(dev)) @spec self_node(t()) :: {:ok, Tailscale.NodeInfo.t()} | {:error, any()} @doc """ Get this node's `m:Tailscale.NodeInfo`. """ - defdelegate self_node(dev), to: Tailscale.Native + def self_node(dev), do: Tailscale.Util.await(Tailscale.Native.self_node(dev)) @spec peer_by_name(t(), String.t()) :: {:ok, Tailscale.NodeInfo.t() | nil} | {:error, any()} @doc """ @@ -123,7 +127,8 @@ defmodule Tailscale do Returns `{:ok, nil}` if there was no such peer, and `{:error, reason}` if the lookup encountered an error. """ - def peer_by_name(dev, name), do: Tailscale.Native.peer_by_name(dev, name) + def peer_by_name(dev, name), + do: Tailscale.Util.await(Tailscale.Native.peer_by_name(dev, name)) @spec peer_by_tailnet_ip(t(), Tailscale.ip_addr()) :: {:ok, Tailscale.NodeInfo.t() | nil} | {:error, any()} @@ -132,12 +137,14 @@ defmodule Tailscale do Returns `{:ok, nil}` if there was no such peer. `:error` if the lookup encountered an error. """ - defdelegate peer_by_tailnet_ip(dev, ip), to: Tailscale.Native + def peer_by_tailnet_ip(dev, ip), + do: Tailscale.Util.await(Tailscale.Native.peer_by_tailnet_ip(dev, ip)) @spec peers_with_route(t(), Tailscale.ip_addr()) :: {:ok, [Tailscale.NodeInfo.t()]} | {:error, any()} @doc """ Retrieve the most narrow set of peers that accept packets for the specified IP. """ - defdelegate peers_with_route(dev, ip), to: Tailscale.Native + def peers_with_route(dev, ip), + do: Tailscale.Util.await(Tailscale.Native.peers_with_route(dev, ip)) end diff --git a/ts_elixir/lib/tailscale/native.ex b/ts_elixir/lib/tailscale/native.ex index abd7dfdd..76ec50cb 100644 --- a/ts_elixir/lib/tailscale/native.ex +++ b/ts_elixir/lib/tailscale/native.ex @@ -32,6 +32,31 @@ defmodule Tailscale.Native do """ @opaque tcp_stream :: reference() + @typedoc """ + NIFs provided here may have asynchronous effects that would typically block and require the use of + the DirtyIO scheduler. This is undesirable as we may have a large number of concurrent calls into + the NIFs, which could exhaust the DirtyIO thread pool. Instead, we use message passing on the Rust + side to send replies back into the BEAM. Functions that use this model return `async_reply` + without blocking. The `:async` case means the reply will be sent asynchronously using a message of + the format `{:tailscale, REF, PAYLOAD}`, where `REF` is the reference associated with the `:async` + response, guaranteed unique per call. + + The `:error` response means that an error was encountered before dispatching the asynchronous + call. + + The `:nif_panic` response means that the NIF panicked during execution; the second parameter is + the reason for the panic (if given). + + `{:raise, TERM}` means `TERM` should be raised as an exception. + + `m:Tailscale.Util` has helpers for decoding messages of this form. + """ + @type async_reply() :: + {:async, reference()} + | {:error, any()} + | {:nif_panic, String.t() | {}} + | {:raise, any()} + defp err, do: :erlang.nif_error(:nif_not_loaded) @doc """ @@ -39,7 +64,7 @@ defmodule Tailscale.Native do See `t:Tailscale.options/0` for details on what options are supported. """ - @spec connect(%{}) :: {:ok, device()} | {:error, any()} + @spec connect(%{}) :: async_reply() def connect(_opts), do: err() @doc """ @@ -51,7 +76,7 @@ defmodule Tailscale.Native do - `port`: the port to which the socket should bind. """ @spec udp_bind(device(), Tailscale.ip_addr() | :ip4 | :ip6, :inet.port_number()) :: - {:ok, udp_socket()} | {:error, any()} + async_reply() def udp_bind(_dev, _addr, _port), do: err() @doc """ @@ -65,14 +90,14 @@ defmodule Tailscale.Native do - `msg`: the packet to send. """ @spec udp_send(udp_socket(), Tailscale.ip_addr(), :inet.port_number(), binary()) :: - :ok | {:error, any()} + async_reply() def udp_send(_sock, _ip, _port, _msg), do: err() @doc """ Receive an incoming UDP packet on the given socket. """ @spec udp_recv(udp_socket()) :: - {:ok, :inet.ip_address(), :inet.port_number(), binary()} | {:error, any()} + async_reply() def udp_recv(_sock), do: err() @doc """ @@ -92,7 +117,7 @@ defmodule Tailscale.Native do Start a TCP listener on the given device, address, and port. """ @spec tcp_listen(device(), Tailscale.ip_addr() | :ip4 | :ip6, :inet.port_number()) :: - {:ok, tcp_listener()} | {:error, any()} + async_reply() def tcp_listen(_dev, _addr, _port), do: err() @doc """ @@ -105,13 +130,13 @@ defmodule Tailscale.Native do Connect to the given TCP endpoint using the given device. """ @spec tcp_connect(device(), Tailscale.ip_addr(), :inet.port_number()) :: - {:ok, tcp_stream()} | {:error, any()} + async_reply() def tcp_connect(_dev, _addr, _port), do: err() @doc """ Accept an incoming TCP connection. Blocks until one is available. """ - @spec tcp_accept(tcp_listener()) :: {:ok, tcp_stream()} | {:error, any()} + @spec tcp_accept(tcp_listener()) :: async_reply() def tcp_accept(_listener), do: err() @doc """ @@ -120,13 +145,13 @@ defmodule Tailscale.Native do Returns the number of bytes actually written to the remote. """ - @spec tcp_send(tcp_stream(), binary()) :: {:ok, integer()} | {:error, any()} + @spec tcp_send(tcp_stream(), binary()) :: async_reply() def tcp_send(_stream, _msg), do: err() @doc """ Receive incoming data from the tcp socket, blocking until at least one byte can be received. """ - @spec tcp_recv(tcp_stream()) :: {:ok, binary()} | {:error, any()} + @spec tcp_recv(tcp_stream()) :: async_reply() def tcp_recv(_stream), do: err() @doc """ @@ -146,7 +171,7 @@ defmodule Tailscale.Native do Blocks until the device is connected and gets its address from control. """ - @spec ipv4_addr(device()) :: {:ok, :inet.ip4_address()} | {:error, any()} + @spec ipv4_addr(device()) :: async_reply() def ipv4_addr(_dev), do: err() @doc """ @@ -154,36 +179,68 @@ defmodule Tailscale.Native do Blocks until the device is connected and gets its address from control. """ - @spec ipv6_addr(device()) :: {:ok, :inet.ip6_address()} | {:error, any()} + @spec ipv6_addr(device()) :: async_reply() def ipv6_addr(_dev), do: err() @doc """ Retrieve a peer by name. """ - @spec peer_by_name(device(), String.t()) :: {:ok, %{} | nil} | {:error, any()} + @spec peer_by_name(device(), String.t()) :: async_reply() def peer_by_name(_dev, _name), do: err() @doc """ Retrieve this node's info """ - @spec self_node(device()) :: {:ok, %{}} | {:error, any()} + @spec self_node(device()) :: async_reply() def self_node(_dev), do: err() @doc """ Retrieve a peer by its tailnet IP. """ - @spec peer_by_tailnet_ip(device(), Tailscale.ip_addr()) :: {:ok, %{} | nil} | {:error, any()} + @spec peer_by_tailnet_ip(device(), Tailscale.ip_addr()) :: async_reply() def peer_by_tailnet_ip(_dev, _ip), do: err() @doc """ Retrieve the most narrow set of peers that accept packets for the specified IP. """ - @spec peers_with_route(device(), Tailscale.ip_addr()) :: {:ok, [%{}]} | {:error, any()} + @spec peers_with_route(device(), Tailscale.ip_addr()) :: async_reply() def peers_with_route(_dev, _ip), do: err() @doc """ Load key state from the specified path, generating a new state if the file doesn't exist. """ - @spec load_key_file(String.t()) :: {:ok, Tailscale.Keystate.t()} | {:error, any()} + @spec load_key_file(String.t()) :: async_reply() def load_key_file(_path), do: err() + + @doc """ + Raise a `:badarg` exception. + """ + @spec raise_badarg() :: nil + def raise_badarg(), do: err() + + if @testing_nifs do + @doc """ + DEV ONLY: trigger an async panic in the Rust code with the given message (if provided). + """ + @spec async_panic(String.t() | nil) :: async_reply() + def async_panic(_msg \\ nil), do: err() + + @doc """ + DEV ONLY: trigger a raised exception in the Rust code with the given message. + """ + @spec async_raise(String.t(), boolean()) :: async_reply() + def async_raise(_msg, _atom \\ false), do: err() + + @doc """ + DEV ONLY: trigger an asynchronous error in the Rust code with the given message. + """ + @spec async_error(String.t(), boolean()) :: async_reply() + def async_error(_msg, _atom \\ false), do: err() + + @doc """ + DEV ONLY: trigger an asynchronous `:badarg` in the Rust code with the given message. + """ + @spec async_badarg() :: async_reply() + def async_badarg(), do: err() + end end diff --git a/ts_elixir/lib/tailscale/tcp.ex b/ts_elixir/lib/tailscale/tcp.ex index a51a378d..3d63bc37 100644 --- a/ts_elixir/lib/tailscale/tcp.ex +++ b/ts_elixir/lib/tailscale/tcp.ex @@ -1,4 +1,6 @@ defmodule Tailscale.Tcp do + require Tailscale.Util + @moduledoc """ Functionality to create tailscale TCP sockets. @@ -19,7 +21,7 @@ defmodule Tailscale.Tcp do @spec listen(Tailscale.t(), Tailscale.ip_addr() | :ip4 | :ip6, :inet.port_number()) :: {:ok, Tailscale.Tcp.Listener.t()} | {:error, any()} def listen(dev, addr, port) do - Tailscale.Native.tcp_listen(dev, addr, port) + Tailscale.Util.await(Tailscale.Native.tcp_listen(dev, addr, port)) end @doc """ @@ -28,6 +30,6 @@ defmodule Tailscale.Tcp do @spec connect(Tailscale.t(), Tailscale.ip_addr(), :inet.port_number()) :: {:ok, Tailscale.Tcp.Stream.t()} | {:error, any()} def connect(dev, addr, port) do - Tailscale.Native.tcp_connect(dev, addr, port) + Tailscale.Util.await(Tailscale.Native.tcp_connect(dev, addr, port)) end end diff --git a/ts_elixir/lib/tailscale/tcp/listener.ex b/ts_elixir/lib/tailscale/tcp/listener.ex index c26da408..d15889a9 100644 --- a/ts_elixir/lib/tailscale/tcp/listener.ex +++ b/ts_elixir/lib/tailscale/tcp/listener.ex @@ -1,4 +1,6 @@ defmodule Tailscale.Tcp.Listener do + require Tailscale.Util + @moduledoc """ Tailscale TCP listening socket functionality. """ @@ -15,7 +17,7 @@ defmodule Tailscale.Tcp.Listener do Blocks until a connection is ready. """ def accept(res) do - Tailscale.Native.tcp_accept(res) + Tailscale.Util.await(Tailscale.Native.tcp_accept(res)) end @doc """ diff --git a/ts_elixir/lib/tailscale/tcp/stream.ex b/ts_elixir/lib/tailscale/tcp/stream.ex index 189ee9f3..58184574 100644 --- a/ts_elixir/lib/tailscale/tcp/stream.ex +++ b/ts_elixir/lib/tailscale/tcp/stream.ex @@ -3,6 +3,8 @@ defmodule Tailscale.Tcp.Stream do Tailscale TCP sockets (connected). """ + require Tailscale.Util + @typedoc """ A handle to a TCP stream (connected socket). """ @@ -15,7 +17,7 @@ defmodule Tailscale.Tcp.Stream do Returns the number of bytes actually sent. """ def send(res, msg) do - Tailscale.Native.tcp_send(res, msg) + Tailscale.Util.await(Tailscale.Native.tcp_send(res, msg)) end @spec send_all(t(), binary()) :: :ok | {:error, any()} @@ -27,7 +29,7 @@ defmodule Tailscale.Tcp.Stream do case Tailscale.Tcp.Stream.send(res, msg) do {:ok, ^len} -> :ok - {:ok, n} -> Tailscale.Tcp.Stream.send_all(res, binary_slice(msg, n..len)) + {:ok, n} -> send_all(res, binary_slice(msg, n..len)) err -> err end end @@ -37,7 +39,7 @@ defmodule Tailscale.Tcp.Stream do Receive data from the TCP socket, blocking until at least one byte can be received. """ def recv(res) do - Tailscale.Native.tcp_recv(res) + Tailscale.Util.await(Tailscale.Native.tcp_recv(res)) end @spec local_addr(t()) :: {:inet.ip_address(), :inet.port_number()} diff --git a/ts_elixir/lib/tailscale/udp.ex b/ts_elixir/lib/tailscale/udp.ex index 4d6a5993..d97720c3 100644 --- a/ts_elixir/lib/tailscale/udp.ex +++ b/ts_elixir/lib/tailscale/udp.ex @@ -1,4 +1,6 @@ defmodule Tailscale.Udp do + require Tailscale.Util + @moduledoc """ Tailscale UDP sockets. """ @@ -21,7 +23,7 @@ defmodule Tailscale.Udp do - `port`: the port number to bind. """ def bind(dev, addr, port) do - Tailscale.Native.udp_bind(dev, addr, port) + Tailscale.Util.await(Tailscale.Native.udp_bind(dev, addr, port)) end @spec send(t(), Tailscale.ip_addr(), :inet.port_number(), binary()) :: :ok | {:error, any()} @@ -37,7 +39,7 @@ defmodule Tailscale.Udp do - `payload`: the message payload. """ def send(sock, remote, port, payload) do - Tailscale.Native.udp_send(sock, remote, port, payload) + Tailscale.Util.await(Tailscale.Native.udp_send(sock, remote, port, payload)) end @spec recv(t()) :: {:ok, Tailscale.ip_addr(), :inet.port_number(), binary()} | {:error, any()} @@ -45,7 +47,7 @@ defmodule Tailscale.Udp do Receive a packet from the socket, blocking until one is ready. """ def recv(sock) do - Tailscale.Native.udp_recv(sock) + Tailscale.Util.await(Tailscale.Native.udp_recv(sock)) end @doc """ diff --git a/ts_elixir/lib/tailscale/util.ex b/ts_elixir/lib/tailscale/util.ex new file mode 100644 index 00000000..33e4ed37 --- /dev/null +++ b/ts_elixir/lib/tailscale/util.ex @@ -0,0 +1,59 @@ +defmodule Tailscale.Util do + @moduledoc false + # Internal utilities. + + @doc """ + Helper to await a Rust-side-async function that responds via message passing. + + Assumes the callee `block` returns the `:async` branch of `t:Tailscale.Native.async_reply/0`. Any + other response is returned verbatim, assumed to be an error. + """ + defmacro await(block, timeout \\ :infinity) do + quote do + Task.async(fn -> + Tailscale.Util.await_local(unquote(block), :infinity) + end) + |> Task.await(unquote(timeout)) + end + end + + @doc """ + Helper to await a Rust-side-async function that responds via message passing. + + Assumes the callee `block` returns the `:async` branch of `t:Tailscale.Native.async_reply/0`. Any + other response is returned verbatim, assumed to be an error. + + This macro (unlike `Tailscale.Util.await/2`) awaits a response message in the current process + without spawning a `m:Task`. This may be desirable to avoid the slight overhead of spawning a new + process, but may not be preferred if this process's mailbox is likely to be busy. + """ + defmacro await_local(block, timeout \\ :infinity) do + quote do + case unquote(block) do + {:async, ref} -> + receive do + {{:tailscale, ^ref}, result} -> result + after + unquote(timeout) -> + {:error, :timeout} + end + + other -> + other + end + |> Tailscale.Util.normalize_result() + end + end + + @doc """ + Normalize an async result to a standard Elixir-shaped return. + """ + def normalize_result({:ok, _} = result), do: normalize_tuple(result) + def normalize_result({:nif_panic, _} = result), do: {:error, normalize_tuple(result)} + def normalize_result({:raise, :badarg}), do: Tailscale.Native.raise_badarg() + def normalize_result({:raise, t}), do: raise(t) + def normalize_result(otherwise), do: otherwise + + defp normalize_tuple({a, {}}), do: a + defp normalize_tuple(a), do: a +end diff --git a/ts_elixir/native/ts_elixir/Cargo.toml b/ts_elixir/native/ts_elixir/Cargo.toml index 94d7850c..8b6c8028 100644 --- a/ts_elixir/native/ts_elixir/Cargo.toml +++ b/ts_elixir/native/ts_elixir/Cargo.toml @@ -11,10 +11,11 @@ license.workspace = true rust-version.workspace = true [dependencies] -rustler = "0.37.2" - tailscale = { workspace = true } +futures-util.workspace = true +rustler = "0.37.2" +tap = "1.0" tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/ts_elixir/native/ts_elixir/src/async_reply.rs b/ts_elixir/native/ts_elixir/src/async_reply.rs new file mode 100644 index 00000000..15f276d7 --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/async_reply.rs @@ -0,0 +1,160 @@ +//! Facilities for sending asynchronous responses from NIFs. +//! +//! The motivation is that the Erlang DirtyIO scheduler is a thread-pool with inherently +//! limited concurrency (n = threads in the pool), and our NIFs will have to block a whole +//! thread on that pool while they're doing anything, even if it's asynchronous work running +//! on tokio. +//! +//! To avoid that, we adopt a more Erlang/Elixir-oriented approach and respond via message +//! passing. This doesn't block the BEAM at all, specifically because we're essentially +//! interacting with its event loop directly, rather than through the opaque abstraction of +//! a function running on a foreign thread. +//! +//! Our NIFs with async work to do return immediately with `{:async, REF}`, where `REF` is a +//! BEAM reference that uniquely identifies the function invocation. In the background, we +//! do whatever we need to and reply eventually (to the original caller's pid) with a +//! message holding the result of the function call and the original `REF` for correlation. + +use std::panic::AssertUnwindSafe; + +use futures_util::FutureExt; +use rustler::{Encoder, NifResult, OwnedEnv, Term}; + +use crate::{TOKIO_RUNTIME, atoms}; + +pub type AsyncReply<'a> = (rustler::Atom, rustler::Reference<'a>); + +/// Spiritual reimplementation of [`rustler::thread::spawn`] for futures. +/// +/// `fut` is executed in a tokio task, and the result is passed to `post`, which encodes a +/// result to pass back to `pid` as a message. If `fut` or `post` panics, `on_panic` is +/// invoked instead with the encoded reason for the panic, and the returned term is passed +/// back to the calling `pid`. +/// +/// Returns a [`rustler::Reference`] which uniquely identifies this particular spawn call. +/// The same reference is also provided to `post` and `on_panic`; they may or may not choose +/// to make use of it for correlation. +/// +/// NB: this function intentionally does not encode any specifics about the response format. +/// Conventionally, our NIFs respond with `{{:tailscale, REF}, PAYLOAD}`, but this function +/// is general-purpose and doesn't make that assumption: it responds with whatever you +/// tell it to, in the interest of separating concerns. The pieces specific to our current +/// async reply convention are encoded in [`reply_async`] and [`try_reply_async`]. +pub fn spawn( + env: rustler::Env, + fut: F, + post: Post, + on_panic: OnPanic, +) -> rustler::Reference +where + F: Future + Send + 'static, + F::Output: std::panic::UnwindSafe, + Post: for<'env> FnOnce(rustler::Env<'env>, rustler::Reference<'env>, F::Output) -> Term<'env> + + Send + + std::panic::UnwindSafe + + 'static, + OnPanic: for<'env> FnOnce(rustler::Env<'env>, rustler::Reference<'env>, Term<'env>) -> Term<'env> + + Send + + 'static, +{ + let pid = env.pid(); + let ref_ = env.make_ref(); + + let mut env = OwnedEnv::new(); + let saved_ref = env.save(ref_); + + TOKIO_RUNTIME.spawn(async move { + let result = AssertUnwindSafe(fut).catch_unwind().await.and_then(|result| { + std::panic::catch_unwind(|| { + if env.run(|env| { + let ref_ = saved_ref.load(env).decode::().unwrap(); + let value = post(env, ref_, result).encode(env); + + env.send(&pid, value) + }).is_err() { + tracing::error!(target_pid = ?pid.as_c_arg(), "failed sending success reply from spawn, process dead?"); + } + }) + }); + + if let Err(err) = result { + let send_result = env.send_and_clear(&pid, move |env| { + let ref_ = saved_ref.load(env).decode::().unwrap(); + + let reason = if let Some(string) = err.downcast_ref::() { + string.encode(env) + } else if let Some(&s) = err.downcast_ref::<&'static str>() { + s.encode(env) + } else { + ().encode(env) + }; + + on_panic(env, ref_, reason) + }); + + if send_result.is_err() { + tracing::error!(target_pid = ?pid.as_c_arg(), "failed sending panic reply from spawn, process dead?"); + } + } + }); + + ref_ +} + +/// Convenience wrapper for [`spawn`] when the return type is [`crate::Result`], +/// automatically converting the response to a reply +/// `{:ok, TERM} | {:error, TERM} | {:nif_panic, TERM} | {:raise | TERM}` wrapped in +/// `{:tailscale, ref, REPLY}`. +pub fn try_reply_async(env: rustler::Env, fut: F) -> AsyncReply +where + F: Future> + Send + 'static, + T: Encoder, +{ + let ref_ = spawn( + env, + async move { AssertUnwindSafe(fut.await) }, + move |env, ref_, t| { + let resp = match t.0 { + Ok(val) => (atoms::ok(), val).encode(env), + Err(e) => encode_async_err(env, e), + }; + + async_resp(ref_, resp).encode(env) + }, + move |env, ref_, reason| async_resp(ref_, (atoms::nif_panic(), reason)).encode(env), + ); + + (atoms::async_(), ref_) +} + +#[rustler::nif] +fn raise_badarg() -> NifResult<()> { + Err(rustler::Error::BadArg) +} + +pub fn async_resp<'r, T>(ref_: rustler::Reference<'r>, value: T) -> (AsyncReply<'r>, T) { + ((atoms::tailscale(), ref_), value) +} + +/// Encode the given [`rustler::Error`] as a [`Term`]. +/// +/// This is needed because [`rustler::Error`] typically expects to be returned from a NIF, +/// where it can directly raise exceptions on the [`Env`]. We don't want to do that here, we +/// want to forward the exception to raise through message passing. On the Elixir side, +/// `Tailscale.Util.normalize_result` handles converting the value into the correct form +/// (`{:error, TERM}` or a raised exception). +fn encode_async_err(env: rustler::Env, err: rustler::Error) -> Term { + match err { + rustler::Error::Term(b) => (atoms::error(), b.encode(env)).encode(env), + rustler::Error::Atom(a) => match rustler::Atom::from_str(env, a) { + Ok(atom) => env.error_tuple(atom), + Err(_e) => (atoms::raise(), atoms::badarg()).encode(env), + }, + rustler::Error::BadArg => (atoms::raise(), atoms::badarg()).encode(env), + rustler::Error::RaiseAtom(atom) => match rustler::Atom::from_str(env, atom) { + Ok(atom) => (atoms::raise(), atom).encode(env), + Err(_e) => (atoms::raise(), atoms::badarg()).encode(env), + }, + rustler::Error::RaiseTerm(t) => (atoms::raise(), t.encode(env)).encode(env), + } +} diff --git a/ts_elixir/native/ts_elixir/src/helpers.rs b/ts_elixir/native/ts_elixir/src/helpers.rs index b9a324d6..80d4d393 100644 --- a/ts_elixir/native/ts_elixir/src/helpers.rs +++ b/ts_elixir/native/ts_elixir/src/helpers.rs @@ -1,10 +1,8 @@ -use std::{error::Error, fmt::Display, net::SocketAddr}; +use std::{fmt::Display, net::SocketAddr}; -use rustler::{Encoder, NifResult, ResourceArc, Term}; +use rustler::{NifResult, ResourceArc}; -use crate::{atoms, erl_ip::ErlIp}; - -pub type Result = std::result::Result>; +use crate::erl_ip::ErlIp; /// Wrap the given [`rustler::Resource`] in a [`ResourceArc`] inside a [`NifResult`]. pub fn ok_arc(t: T) -> NifResult> @@ -22,10 +20,3 @@ pub fn term_err(e: impl Display) -> rustler::Error { pub fn sockaddr_to_erl(addr: SocketAddr) -> (ErlIp, u16) { (ErlIp(addr.ip()), addr.port()) } - -pub fn erl_result(env: rustler::Env, r: Result) -> NifResult { - match r { - Ok(t) => Ok((atoms::ok(), t).encode(env)), - Err(e) => Err(term_err(e)), - } -} diff --git a/ts_elixir/native/ts_elixir/src/lib.rs b/ts_elixir/native/ts_elixir/src/lib.rs index d0a52219..52820651 100644 --- a/ts_elixir/native/ts_elixir/src/lib.rs +++ b/ts_elixir/native/ts_elixir/src/lib.rs @@ -5,9 +5,11 @@ use std::{ sync::{Arc, LazyLock, Once}, }; -use rustler::{Encoder, NifResult, ResourceArc, Term}; +use rustler::{NifResult, ResourceArc, Term}; +use tap::Pipe; use tracing::level_filters::LevelFilter; +mod async_reply; mod config; mod erl_ip; mod helpers; @@ -16,23 +18,28 @@ mod node_info; mod tcp; mod udp; +use async_reply::{AsyncReply, try_reply_async}; use config::Keystate; use erl_ip::ErlIp; -use helpers::{Result, ok_arc, sockaddr_to_erl, term_err}; +use helpers::{ok_arc, sockaddr_to_erl, term_err}; use ip_or_self::IpOrSelf; use node_info::NodeInfo; use tcp::{TcpListener, TcpStream}; use udp::UdpSocket; -use crate::helpers::erl_result; - mod atoms { rustler::atoms! { ok, + async_ = "async", error, + nif_panic, + badarg, + raise, ip4, ip6, + + tailscale, } } @@ -69,14 +76,14 @@ fn start_tracing() { }); } -#[rustler::nif(schedule = "DirtyIo")] +#[rustler::nif] fn connect<'env>( env: rustler::Env<'env>, opts: HashMap>, -) -> NifResult> { +) -> NifResult> { let (config, auth_key) = config::config_from_erl(&opts)?; - let dev = TOKIO_RUNTIME.block_on(async move { + try_reply_async(env, async move { let dev = tailscale::Device::new(&config, auth_key) .await .map_err(term_err)?; @@ -84,84 +91,92 @@ fn connect<'env>( ok_arc(Device { inner: Arc::new(dev), }) - }); - - dev.map(|d| d.encode(env)) + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn load_key_file(env: rustler::Env, path: &str) -> NifResult { - let result = TOKIO_RUNTIME - .block_on(tailscale::config::load_key_file(path, Default::default())) - .map(Keystate::from) - .map_err(Into::into); - - erl_result(env, result) +#[rustler::nif] +fn load_key_file(env: rustler::Env, path: String) -> AsyncReply { + try_reply_async(env, async move { + tailscale::config::load_key_file(path, Default::default()) + .await + .map(Keystate::from) + .map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn ipv4_addr(env: rustler::Env, dev: ResourceArc) -> NifResult { +#[rustler::nif] +fn ipv4_addr(env: rustler::Env, dev: ResourceArc) -> AsyncReply { let dev = dev.inner.clone(); - let addr = TOKIO_RUNTIME.block_on(dev.ipv4_addr()); - erl_result(env, addr.map(|ip| ErlIp(ip.into())).map_err(Into::into)) + try_reply_async(env, async move { + dev.ipv4_addr().await.map(ErlIp::from).map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn ipv6_addr(dev: ResourceArc) -> NifResult { +#[rustler::nif] +fn ipv6_addr(env: rustler::Env<'_>, dev: ResourceArc) -> AsyncReply<'_> { let dev = dev.inner.clone(); - TOKIO_RUNTIME - .block_on(dev.ipv6_addr()) - .map(ErlIp::from) - .map_err(term_err) + try_reply_async(env, async move { + dev.ipv6_addr().await.map(ErlIp::from).map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn peer_by_name(env: rustler::Env<'_>, dev: ResourceArc, name: &str) -> impl Encoder { +#[rustler::nif] +fn peer_by_name<'e>(env: rustler::Env<'e>, dev: ResourceArc, name: &str) -> AsyncReply<'e> { let dev = dev.inner.clone(); let name = name.to_owned(); - match TOKIO_RUNTIME.block_on(async move { dev.peer_by_name(&name).await }) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(None) => (atoms::ok(), Option::<()>::None).encode(env), - Ok(Some(peer)) => (atoms::ok(), NodeInfo::from(peer)).encode(env), - } + try_reply_async(env, async move { + dev.peer_by_name(&name) + .await + .map(|opt| opt.map(NodeInfo::from)) + .map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn self_node(env: rustler::Env<'_>, dev: ResourceArc) -> impl Encoder { +#[rustler::nif] +fn self_node(env: rustler::Env<'_>, dev: ResourceArc) -> AsyncReply<'_> { let dev = dev.inner.clone(); - match TOKIO_RUNTIME.block_on(async move { dev.self_node().await }) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(peer) => (atoms::ok(), NodeInfo::from(peer)).encode(env), - } + try_reply_async(env, async move { + dev.self_node().await.map(NodeInfo::from).map_err(term_err) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn peer_by_tailnet_ip(env: rustler::Env<'_>, dev: ResourceArc, ip: ErlIp) -> impl Encoder { +#[rustler::nif] +fn peer_by_tailnet_ip<'e>( + env: rustler::Env<'e>, + dev: ResourceArc, + ip: ErlIp, +) -> NifResult> { let dev = dev.inner.clone(); - match TOKIO_RUNTIME.block_on(async move { dev.peer_by_tailnet_ip(ip.0).await }) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(None) => (atoms::ok(), Option::<()>::None).encode(env), - Ok(Some(peer)) => (atoms::ok(), NodeInfo::from(peer)).encode(env), - } + try_reply_async(env, async move { + dev.peer_by_tailnet_ip(ip.into()) + .await + .map(|x| x.map(NodeInfo::from)) + .map_err(term_err) + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn peers_with_route(env: rustler::Env<'_>, dev: ResourceArc, ip: ErlIp) -> impl Encoder { +#[rustler::nif] +fn peers_with_route<'e>( + env: rustler::Env<'e>, + dev: ResourceArc, + ip: ErlIp, +) -> NifResult> { let dev = dev.inner.clone(); - match TOKIO_RUNTIME.block_on(async move { dev.peers_with_route(ip.0).await }) { - Err(e) => (atoms::error(), e.to_string()).encode(env), - Ok(peers) => ( - atoms::ok(), - peers.into_iter().map(NodeInfo::from).collect::>(), - ) - .encode(env), - } + try_reply_async(env, async move { + dev.peers_with_route(ip.into()) + .await + .map(|peers| peers.into_iter().map(NodeInfo::from).collect::>()) + .map_err(term_err) + }) + .pipe(Ok) } fn load(env: rustler::Env, _term: Term) -> bool { diff --git a/ts_elixir/native/ts_elixir/src/tcp.rs b/ts_elixir/native/ts_elixir/src/tcp.rs index 4ef8da9d..5db2f0a7 100644 --- a/ts_elixir/native/ts_elixir/src/tcp.rs +++ b/ts_elixir/native/ts_elixir/src/tcp.rs @@ -1,9 +1,11 @@ use std::sync::Arc; use rustler::{Encoder, NifResult, ResourceArc}; +use tap::Pipe; use crate::{ - IpOrSelf, Result, TOKIO_RUNTIME, atoms, erl_ip::ErlIp, erl_result, helpers::term_err, ok_arc, + AsyncReply, IpOrSelf, erl_ip::ErlIp, helpers::term_err, ok_arc, sockaddr_to_erl, + try_reply_async, }; pub(crate) struct TcpListener { @@ -20,16 +22,17 @@ impl rustler::Resource for TcpListener {} #[rustler::resource_impl] impl rustler::Resource for TcpStream {} -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_listen( +#[rustler::nif] +fn tcp_listen<'e>( + env: rustler::Env<'e>, dev: ResourceArc, - addr: IpOrSelf, + ip: IpOrSelf, port: u16, -) -> NifResult { +) -> NifResult> { let dev = dev.inner.clone(); - TOKIO_RUNTIME.block_on(async move { - let addr = addr.resolve(&dev).await?; + try_reply_async(env, async move { + let addr = ip.resolve(&dev).await?; let sock = dev .tcp_listen((addr, port).into()) .await @@ -39,79 +42,75 @@ fn tcp_listen( inner: Arc::new(sock), }) }) + .pipe(Ok) } #[rustler::nif] fn tcp_listen_local_addr(listener: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(listener.inner.local_addr()) + sockaddr_to_erl(listener.inner.local_addr()) } -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_connect( - env: rustler::Env, +#[rustler::nif] +fn tcp_connect<'e>( + env: rustler::Env<'e>, dev: ResourceArc, addr: ErlIp, port: u16, -) -> NifResult { +) -> NifResult> { let dev = dev.inner.clone(); - TOKIO_RUNTIME - .block_on(async move { - let sock = dev - .tcp_connect((addr, port).into()) - .await - .map_err(term_err)?; + try_reply_async(env, async move { + let sock = dev + .tcp_connect((addr, port).into()) + .await + .map_err(term_err)?; - ok_arc(TcpStream { - inner: Arc::new(sock), - }) + ok_arc(TcpStream { + inner: Arc::new(sock), }) - .map(|sock| sock.encode(env)) + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_accept(env: rustler::Env<'_>, sock: ResourceArc) -> NifResult { +#[rustler::nif] +fn tcp_accept(env: rustler::Env<'_>, sock: ResourceArc) -> AsyncReply<'_> { let inner = sock.inner.clone(); - TOKIO_RUNTIME - .block_on(async move { - let stream = inner.accept().await.map_err(term_err)?; + try_reply_async(env, async move { + let stream = inner.accept().await.map_err(term_err)?; - ok_arc(TcpStream { - inner: Arc::new(stream), - }) + ok_arc(TcpStream { + inner: Arc::new(stream), }) - .map(|sock| sock.encode(env)) + }) } -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_send(env: rustler::Env, sock: ResourceArc, msg: Vec) -> rustler::Term { +#[rustler::nif] +fn tcp_send(env: rustler::Env, sock: ResourceArc, msg: Vec) -> AsyncReply { let inner = sock.inner.clone(); - match TOKIO_RUNTIME.block_on(async move { inner.send(&msg).await }) { - Ok(n) => (atoms::ok(), n).encode(env), - Err(e) => (atoms::error(), e.to_string()).encode(env), - } + try_reply_async(env, async move { inner.send(&msg).await.map_err(term_err) }) } -#[rustler::nif(schedule = "DirtyIo")] -fn tcp_recv(env: rustler::Env, sock: ResourceArc) -> NifResult { +#[rustler::nif] +fn tcp_recv(env: rustler::Env, sock: ResourceArc) -> AsyncReply { let inner = sock.inner.clone(); - let buf = TOKIO_RUNTIME.block_on(async move { - let buf = inner.recv_bytes().await?; - Result::<_>::Ok(buf.to_vec()) - }); - - erl_result(env, buf) + try_reply_async(env, async move { + inner + .recv_bytes() + .await + .map(|b| b.to_vec()) + .map_err(term_err) + }) } #[rustler::nif] fn tcp_local_addr(sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(sock.inner.local_addr()) + sockaddr_to_erl(sock.inner.local_addr()) } #[rustler::nif] fn tcp_remote_addr(sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(sock.inner.remote_addr()) + sockaddr_to_erl(sock.inner.remote_addr()) } diff --git a/ts_elixir/native/ts_elixir/src/udp.rs b/ts_elixir/native/ts_elixir/src/udp.rs index 73a5e869..f2bf5e3f 100644 --- a/ts_elixir/native/ts_elixir/src/udp.rs +++ b/ts_elixir/native/ts_elixir/src/udp.rs @@ -1,9 +1,11 @@ use std::sync::Arc; -use rustler::{Binary, Encoder, NifResult, ResourceArc, Term}; +use rustler::{Binary, Encoder, NifResult, ResourceArc}; +use tap::Pipe; use crate::{ - Device, IpOrSelf, Result, TOKIO_RUNTIME, atoms, erl_ip::ErlIp, helpers::term_err, ok_arc, + AsyncReply, Device, IpOrSelf, erl_ip::ErlIp, helpers::term_err, ok_arc, sockaddr_to_erl, + try_reply_async, }; pub struct UdpSocket { @@ -13,56 +15,59 @@ pub struct UdpSocket { #[rustler::resource_impl] impl rustler::Resource for UdpSocket {} -#[rustler::nif(schedule = "DirtyIo")] -fn udp_bind( - env: rustler::Env, +#[rustler::nif] +fn udp_bind<'e>( + env: rustler::Env<'e>, dev: ResourceArc, ip: IpOrSelf, port: u16, -) -> NifResult { +) -> NifResult> { let dev = dev.inner.clone(); - TOKIO_RUNTIME - .block_on(async move { - let addr = ip.resolve(&dev).await?; - let sock = dev.udp_bind((addr, port).into()).await.map_err(term_err)?; + try_reply_async(env, async move { + let addr = ip.resolve(&dev).await?; + let sock = dev.udp_bind((addr, port).into()).await.map_err(term_err)?; - ok_arc(UdpSocket { - inner: Arc::new(sock), - }) + ok_arc(UdpSocket { + inner: Arc::new(sock), }) - .map(|sock| sock.encode(env)) + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn udp_send<'env>( - env: rustler::Env<'env>, +#[rustler::nif] +fn udp_send<'e>( + env: rustler::Env<'e>, sock: ResourceArc, - ip: ErlIp, + addr: ErlIp, port: u16, msg: Binary, -) -> Term<'env> { +) -> NifResult> { let msg = msg.to_vec(); let sock = sock.inner.clone(); - match TOKIO_RUNTIME.block_on(async move { - sock.send_to((ip.0, port).into(), &msg).await?; - - Result::<_>::Ok(()) - }) { - Ok(_) => atoms::ok().encode(env), - Err(e) => (atoms::error(), e.to_string()).encode(env), - } + try_reply_async(env, async move { + sock.send_to((addr, port).into(), &msg) + .await + .map(|_| ()) + .map_err(term_err) + }) + .pipe(Ok) } -#[rustler::nif(schedule = "DirtyIo")] -fn udp_recv(env: rustler::Env, sock: ResourceArc) -> NifResult { - let (who, msg) = sock.inner.recv_from_bytes_blocking().map_err(term_err)?; +#[rustler::nif] +fn udp_recv(env: rustler::Env, sock: ResourceArc) -> AsyncReply { + let sock = sock.inner.clone(); - Ok((atoms::ok(), ErlIp(who.ip()), who.port(), msg.to_vec()).encode(env)) + try_reply_async(env, async move { + sock.recv_from_bytes() + .await + .map(|(s, msg)| (ErlIp(s.ip()), s.port(), msg.to_vec())) + .map_err(term_err) + }) } #[rustler::nif] fn udp_local_addr(sock: ResourceArc) -> impl Encoder { - crate::sockaddr_to_erl(sock.inner.local_addr()) + sockaddr_to_erl(sock.inner.local_addr()) } From ce2bf903ecc23014fa365438f102978da3313744 Mon Sep 17 00:00:00 2001 From: Nathan Perry Date: Fri, 24 Apr 2026 07:50:03 -0400 Subject: [PATCH 5/5] elixir: async tests Provide tests to exercise the async callback type conversions. Signed-off-by: Nathan Perry Change-Id: Iafe01153ca52b9f618c80666de5e6f186a6a6964 Change-Id: I9d1d3e23de8a558fc6f0f81677b6237e6a6a6964 Signed-off-by: Nathan Perry --- ts_elixir/config/config.exs | 7 ++ ts_elixir/config/dev.exs | 4 ++ ts_elixir/config/prod.exs | 4 ++ ts_elixir/config/test.exs | 4 ++ ts_elixir/lib/tailscale/native.ex | 17 ++++- ts_elixir/native/ts_elixir/Cargo.toml | 5 ++ ts_elixir/native/ts_elixir/src/lib.rs | 2 + .../native/ts_elixir/src/testing_nifs.rs | 50 ++++++++++++++ ts_elixir/test/async_callback_test.exs | 69 +++++++++++++++++++ 9 files changed, 159 insertions(+), 3 deletions(-) create mode 100644 ts_elixir/config/config.exs create mode 100644 ts_elixir/config/dev.exs create mode 100644 ts_elixir/config/prod.exs create mode 100644 ts_elixir/config/test.exs create mode 100644 ts_elixir/native/ts_elixir/src/testing_nifs.rs create mode 100644 ts_elixir/test/async_callback_test.exs diff --git a/ts_elixir/config/config.exs b/ts_elixir/config/config.exs new file mode 100644 index 00000000..a43c1d18 --- /dev/null +++ b/ts_elixir/config/config.exs @@ -0,0 +1,7 @@ +import Config + +config :tailscale, + testing_nifs: false, + profile: :debug + +import_config "#{config_env()}.exs" diff --git a/ts_elixir/config/dev.exs b/ts_elixir/config/dev.exs new file mode 100644 index 00000000..03b03c6f --- /dev/null +++ b/ts_elixir/config/dev.exs @@ -0,0 +1,4 @@ +import Config + +config :tailscale, + testing_nifs: true diff --git a/ts_elixir/config/prod.exs b/ts_elixir/config/prod.exs new file mode 100644 index 00000000..e58378b4 --- /dev/null +++ b/ts_elixir/config/prod.exs @@ -0,0 +1,4 @@ +import Config + +config :tailscale, + profile: :release diff --git a/ts_elixir/config/test.exs b/ts_elixir/config/test.exs new file mode 100644 index 00000000..03b03c6f --- /dev/null +++ b/ts_elixir/config/test.exs @@ -0,0 +1,4 @@ +import Config + +config :tailscale, + testing_nifs: true diff --git a/ts_elixir/lib/tailscale/native.ex b/ts_elixir/lib/tailscale/native.ex index 76ec50cb..eed11358 100644 --- a/ts_elixir/lib/tailscale/native.ex +++ b/ts_elixir/lib/tailscale/native.ex @@ -1,9 +1,20 @@ defmodule Tailscale.Native do + @moduledoc false + + @testing_nifs Application.compile_env!(:tailscale, :testing_nifs) + @profile Application.compile_env!(:tailscale, :profile) + + @features (if @testing_nifs do + ["testing-nifs"] + else + [] + end) + use Rustler, otp_app: :tailscale, - crate: :ts_elixir - - @moduledoc false + crate: :ts_elixir, + mode: @profile, + features: @features # The Elixir side of the Rustler bindings to `tailscale-rs`. # diff --git a/ts_elixir/native/ts_elixir/Cargo.toml b/ts_elixir/native/ts_elixir/Cargo.toml index 8b6c8028..67bc94d3 100644 --- a/ts_elixir/native/ts_elixir/Cargo.toml +++ b/ts_elixir/native/ts_elixir/Cargo.toml @@ -20,6 +20,11 @@ tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } tracing-subscriber = { version = "0.3", features = ["env-filter"] } +[features] +# Additional testing functions that can directly trigger panics and errors to exercise the code in +# development. +testing-nifs = [] + [lib] crate-type = ["cdylib"] diff --git a/ts_elixir/native/ts_elixir/src/lib.rs b/ts_elixir/native/ts_elixir/src/lib.rs index 52820651..7cabb8e4 100644 --- a/ts_elixir/native/ts_elixir/src/lib.rs +++ b/ts_elixir/native/ts_elixir/src/lib.rs @@ -16,6 +16,8 @@ mod helpers; mod ip_or_self; mod node_info; mod tcp; +#[cfg(feature = "testing-nifs")] +mod testing_nifs; mod udp; use async_reply::{AsyncReply, try_reply_async}; diff --git a/ts_elixir/native/ts_elixir/src/testing_nifs.rs b/ts_elixir/native/ts_elixir/src/testing_nifs.rs new file mode 100644 index 00000000..d761cfbb --- /dev/null +++ b/ts_elixir/native/ts_elixir/src/testing_nifs.rs @@ -0,0 +1,50 @@ +//! NIFs that intentionally return errors, panic, and raise exceptions. +//! +//! These are intended for testing the async message passing code and require the +//! `testing-nifs` feature flag to be enabled. + +use rustler::{Env, Error}; + +use crate::async_reply::{AsyncReply, try_reply_async}; + +#[rustler::nif] +pub fn async_panic(env: Env, msg: Option) -> AsyncReply { + try_reply_async(env, async move { + if let Some(msg) = msg { + panic!("{msg}"); + } else { + panic!() + } + + // Needed to indicate return type + #[allow(unreachable_code)] + Ok(()) + }) +} + +#[rustler::nif] +pub fn async_error<'e>(env: Env<'e>, s: String, atom: bool) -> AsyncReply<'e> { + try_reply_async(env, async move { + Result::<(), _>::Err(if atom { + Error::Atom(String::leak(s)) + } else { + Error::Term(Box::new(s)) + }) + }) +} + +#[rustler::nif] +pub fn async_raise<'e>(env: Env<'e>, s: String, atom: bool) -> AsyncReply<'e> { + try_reply_async(env, async move { + Result::<(), _>::Err(if atom { + Error::RaiseAtom(String::leak(s)) + } else { + Error::RaiseTerm(Box::new(s)) + }) + }) +} + +#[rustler::nif] +pub fn async_badarg<'e>(env: Env<'e>) -> AsyncReply<'e> { + try_reply_async(env, async move { Result::<(), _>::Err(Error::BadArg) }) +} diff --git a/ts_elixir/test/async_callback_test.exs b/ts_elixir/test/async_callback_test.exs new file mode 100644 index 00000000..d4b8d844 --- /dev/null +++ b/ts_elixir/test/async_callback_test.exs @@ -0,0 +1,69 @@ +defmodule Tailscale.Test.AsyncCallbacks do + use ExUnit.Case, async: true + require Tailscale.Util + alias Tailscale.Native + + defmacrop await(block, local) do + if local do + quote do + Tailscale.Util.await_local(unquote(block)) + end + else + quote do + Tailscale.Util.await(unquote(block)) + end + end + end + + for local <- [true, false] do + describe "async calls (local: #{local})" do + for msg <- ["msg", nil] do + test "panic (msg: #{msg})" do + result = await(Native.async_panic(unquote(msg)), local) + + {:error, {:nif_panic, arg}} = result + + if unquote(msg) != nil do + assert(arg == "msg") + end + end + end + + for atom <- [true, false] do + test "error (atom: #{atom})" do + assert( + await(Native.async_error("msg", unquote(atom)), local) == + {:error, + if unquote(atom) do + :msg + else + "msg" + end} + ) + end + + test("raise (atom: #{atom})") do + assert_raise RuntimeError, fn -> + msg = + if unquote(atom) do + "Elixir.RuntimeError" + else + "msg" + end + + await( + Native.async_raise(msg, unquote(atom)), + local + ) + end + end + end + + test "badarg" do + assert_raise ArgumentError, fn -> + await(Native.async_badarg(), local) + end + end + end + end +end