From ccb0fa748fffaec6476cdc81510bdc363a618c34 Mon Sep 17 00:00:00 2001 From: Grant Limberg Date: Thu, 14 Aug 2025 12:31:42 -0700 Subject: [PATCH] updates & tests. Tests currently need to be run with --test-threads=1. Seems like the instances of the pubsub emulator stomp on each other without that --- rustybits/Cargo.lock | 2 + rustybits/Cargo.toml | 14 +- rustybits/src/ext.rs | 2 + rustybits/src/pubsub/change_listener.rs | 29 +++- rustybits/src/pubsub/member_listener.rs | 187 ++++++++++++++++++++--- rustybits/src/pubsub/network_listener.rs | 187 ++++++++++++++++++++--- 6 files changed, 370 insertions(+), 51 deletions(-) diff --git a/rustybits/Cargo.lock b/rustybits/Cargo.lock index 1f0d77521..fd6b37a34 100644 --- a/rustybits/Cargo.lock +++ b/rustybits/Cargo.lock @@ -3101,6 +3101,8 @@ dependencies = [ "base64 0.21.7", "bytes", "cbindgen", + "gcloud-gax", + "gcloud-googleapis", "gcloud-pubsub", "jwt", "openidconnect", diff --git a/rustybits/Cargo.toml b/rustybits/Cargo.toml index 3f0f82b5c..6c1cfeb3b 100644 --- a/rustybits/Cargo.toml +++ b/rustybits/Cargo.toml @@ -18,6 +18,10 @@ ztcontroller = [ "dep:gcloud-pubsub", "dep:prost", "dep:prost-types", + "dep:gcloud-gax", + "dep:gcloud-googleapis", + "dep:tokio", + "dep:tokio-util", ] [dependencies] @@ -28,8 +32,12 @@ temporal-client = { git = "https://github.com/temporalio/sdk-core", branch = "ma "telemetry", ] } temporal-sdk-core-protos = { git = "https://github.com/temporalio/sdk-core", branch = "master", optional = true } -tokio = { version = "1.43", features = ["full", "rt", "macros"] } -tokio-util = { version = "0.7" } +tokio = { version = "1.43", features = [ + "full", + "rt", + "macros", +], optional = true } +tokio-util = { version = "0.7", optional = true } uuid = { version = "1.4", features = ["v4"] } openidconnect = { version = "3.4", default-features = false, features = [ "reqwest", @@ -46,6 +54,8 @@ thiserror = "1" gcloud-pubsub = { version = "1.3.0", optional = true } prost = { version = "0.14", optional = true, features = ["derive"] } prost-types = { version = "0.14", optional = true } +gcloud-gax = { version = "1.2.0", optional = true } +gcloud-googleapis = { version = "1.2.0", optional = true } [dev-dependencies] testcontainers = { version = "0.24", features = ["blocking"] } diff --git a/rustybits/src/ext.rs b/rustybits/src/ext.rs index 7aa850e65..83fb864e6 100644 --- a/rustybits/src/ext.rs +++ b/rustybits/src/ext.rs @@ -25,6 +25,7 @@ static mut RT: Option = None; static START: std::sync::Once = std::sync::Once::new(); static SHUTDOWN: std::sync::Once = std::sync::Once::new(); +#[cfg(feature = "ztcontroller")] #[no_mangle] pub unsafe extern "C" fn init_async_runtime() { START.call_once(|| { @@ -39,6 +40,7 @@ pub unsafe extern "C" fn init_async_runtime() { }); } +#[cfg(feature = "ztcontroller")] #[no_mangle] #[allow(static_mut_refs)] pub unsafe extern "C" fn shutdown_async_runtime() { diff --git a/rustybits/src/pubsub/change_listener.rs b/rustybits/src/pubsub/change_listener.rs index eab3f8e31..c5ae1f925 100644 --- a/rustybits/src/pubsub/change_listener.rs +++ b/rustybits/src/pubsub/change_listener.rs @@ -24,7 +24,12 @@ impl ChangeListener { ) -> Result> { let config = ClientConfig::default().with_auth().await.unwrap(); let client = Client::new(config).await?; + let topic = client.topic(topic_name); + if !topic.exists(None).await? { + topic.create(None, None).await?; + } + Ok(Self { client, topic, @@ -35,6 +40,14 @@ impl ChangeListener { }) } + /** + * Listens for changes on the topic and sends them to the provided sender. + * + * Listens for up to `listen_timeout` duration, at which point it will stop listening + * and return. listen will have to be called again to continue listening. + * + * If the subscription does not exist, it will create it with the specified configuration. + */ pub async fn listen(&self) -> Result<(), Box> { let config = SubscriptionConfig { enable_message_ordering: true, @@ -87,7 +100,7 @@ impl ChangeListener { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use testcontainers::runners::AsyncRunner; @@ -96,20 +109,22 @@ mod tests { use testcontainers_modules::google_cloud_sdk_emulators::CloudSdk; use tokio; - async fn setup_pubsub_emulator() -> Result<(ContainerAsync, String), Box> { + pub(crate) async fn setup_pubsub_emulator() -> Result<(ContainerAsync, String), Box> + { let container = google_cloud_sdk_emulators::CloudSdk::pubsub().start().await?; let port = container.get_host_port_ipv4(8085).await?; let host = format!("localhost:{}", port); + + unsafe { + std::env::set_var("PUBSUB_EMULATOR_HOST", host.clone()); + } + Ok((container, host)) } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_can_connect_to_pubsub() -> Result<(), Box> { - let (_container, host) = setup_pubsub_emulator().await?; - - unsafe { - std::env::set_var("PUBSUB_EMULATOR_HOST", host); - } + let (_container, _host) = setup_pubsub_emulator().await?; let (tx, _rx) = tokio::sync::mpsc::channel(64); diff --git a/rustybits/src/pubsub/member_listener.rs b/rustybits/src/pubsub/member_listener.rs index 860e59cc4..993eaed14 100644 --- a/rustybits/src/pubsub/member_listener.rs +++ b/rustybits/src/pubsub/member_listener.rs @@ -50,36 +50,181 @@ impl MemberListener { })) } - pub async fn listen(&self) -> Result<(), Box> { + pub async fn listen(self: &Arc) -> Result<(), Box> { self.change_listener.listen().await } - pub fn change_handler(self: Arc) -> Result<(), Box> { + pub async fn change_handler(self: &Arc) -> Result<(), Box> { let this = self.clone(); - tokio::spawn(async move { - let mut rx = this.rx_channel.lock().await; - while let Some(change) = rx.recv().await { - if let Ok(m) = MemberChange::decode(change.as_slice()) { - print!("Received change: {:?}", m); - let j = serde_json::to_string(&m).unwrap(); - let mut buffer = [0; 16384]; - let mut test: &mut [u8] = &mut buffer; - let mut size: usize = 0; - while let Ok(bytes) = test.write(j.as_bytes()) { - if bytes == 0 { - break; - } - size += bytes; + let mut rx = this.rx_channel.lock().await; + while let Some(change) = rx.recv().await { + if let Ok(m) = MemberChange::decode(change.as_slice()) { + let j = serde_json::to_string(&m).unwrap(); + let mut buffer = [0; 16384]; + let mut test: &mut [u8] = &mut buffer; + let mut size: usize = 0; + while let Ok(bytes) = test.write(j.as_bytes()) { + if bytes == 0 { + break; } - let callback = this.callback.lock().await; - let user_ptr = this.user_ptr.load(std::sync::atomic::Ordering::Relaxed); - - (callback)(user_ptr, test.as_ptr(), size); + size += bytes; } + let callback = this.callback.lock().await; + let user_ptr = this.user_ptr.load(std::sync::atomic::Ordering::Relaxed); + + (callback)(user_ptr, test.as_ptr(), size); + } else { + eprintln!("Failed to decode change"); } - }); + } Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::pubsub::change_listener::tests::setup_pubsub_emulator; + use crate::pubsub::protobuf::pbmessages::{Member, MemberChange}; + + use gcloud_googleapis::pubsub::v1::PubsubMessage; + use gcloud_pubsub::client::{Client, ClientConfig}; + use std::{ + collections::HashMap, + sync::atomic::{AtomicBool, Ordering}, + }; + + extern "C" fn dummy_callback(user_ptr: *mut c_void, data: *const u8, _size: usize) { + // Dummy callback for testing + assert!(!data.is_null(), "data pointer is null"); + assert!(!user_ptr.is_null(), "user_ptr pointer is null"); + let user_ptr = unsafe { &mut *(user_ptr as *mut TestMemberListener) }; + user_ptr.callback_called(); + println!("Dummy callback invoked"); + } + + struct TestMemberListener { + dummy_callback_called: bool, + } + + impl TestMemberListener { + fn new() -> Self { + Self { dummy_callback_called: false } + } + + fn callback_called(&mut self) { + self.dummy_callback_called = true; + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_member_listener() { + println!("Setting up Pub/Sub emulator for network listener test"); + let (_container, _host) = setup_pubsub_emulator().await.unwrap(); + let mut tester = TestMemberListener::new(); + + let listener = MemberListener::new( + "testctl", + Duration::from_secs(1), + dummy_callback, + &mut tester as *mut TestMemberListener as *mut c_void, + ) + .await + .unwrap(); + + let rt = tokio::runtime::Handle::current(); + + let run = Arc::new(AtomicBool::new(true)); + rt.spawn({ + let run = run.clone(); + let l = listener.clone(); + async move { + while run.load(Ordering::Relaxed) { + match l.listen().await { + Ok(_) => { + println!("Listener exited successfully"); + } + Err(e) => { + println!("Failed to start listener: {}", e); + assert!(false, "Listener failed to start"); + } + } + } + } + }); + + rt.spawn({ + let run = run.clone(); + let l = listener.clone(); + async move { + while run.load(Ordering::Relaxed) { + match l.change_handler().await { + Ok(_) => { + println!("Change handler started successfully"); + } + Err(e) => { + println!("Failed to start change handler: {}", e); + assert!(false, "Change handler failed to start"); + } + } + } + } + }); + + rt.spawn({ + async move { + let client = Client::new(ClientConfig::default()).await.unwrap(); + let topic = client.topic("controller-member-change-stream"); + if !topic.exists(None).await.unwrap() { + topic.create(None, None).await.unwrap(); + } + + let mut publisher = topic.new_publisher(None); + + let nc = MemberChange { + old: Some(Member { + device_id: "test_member".to_string(), + network_id: "test_network".to_string(), + authorized: false, + ..Default::default() + }), + new: Some(Member { + device_id: "test_member".to_string(), + network_id: "test_network".to_string(), + authorized: true, + ..Default::default() + }), + ..Default::default() + }; + + let data = MemberChange::encode_to_vec(&nc); + let message = PubsubMessage { + data: data.into(), + attributes: HashMap::from([("controller_id".to_string(), "testctl".to_string())]), + ordering_key: format!("members-{}", "test_network"), + ..Default::default() + }; + let awaiter = publisher.publish(message).await; + + match awaiter.get().await { + Ok(_) => println!("Message published successfully"), + Err(e) => { + assert!(false, "Failed to publish message: {}", e); + eprintln!("Failed to publish message: {}", e) + } + } + publisher.shutdown().await; + } + }); + + let mut counter = 0; + while !tester.dummy_callback_called && counter < 100 { + tokio::time::sleep(Duration::from_millis(100)).await; + counter += 1; + } + run.store(false, Ordering::Relaxed); + assert!(tester.dummy_callback_called, "Callback was not called"); + } +} diff --git a/rustybits/src/pubsub/network_listener.rs b/rustybits/src/pubsub/network_listener.rs index 111bab893..987369220 100644 --- a/rustybits/src/pubsub/network_listener.rs +++ b/rustybits/src/pubsub/network_listener.rs @@ -51,36 +51,181 @@ impl NetworkListener { })) } - pub async fn listen(&self) -> Result<(), Box> { + pub async fn listen(self: &Arc) -> Result<(), Box> { self.change_listener.listen().await } - pub fn change_handler(self: Arc) -> Result<(), Box> { + pub async fn change_handler(self: &Arc) -> Result<(), Box> { let this = self.clone(); - tokio::spawn(async move { - let mut rx = this.rx_channel.lock().await; - while let Some(change) = rx.recv().await { - if let Ok(m) = NetworkChange::decode(change.as_slice()) { - print!("Received change: {:?}", m); - let j = serde_json::to_string(&m).unwrap(); - let mut buffer = [0; 16384]; - let mut test: &mut [u8] = &mut buffer; - let mut size: usize = 0; - while let Ok(bytes) = test.write(j.as_bytes()) { - if bytes == 0 { - break; // No more space to write - } - size += bytes; + let mut rx = this.rx_channel.lock().await; + while let Some(change) = rx.recv().await { + if let Ok(m) = NetworkChange::decode(change.as_slice()) { + let j = serde_json::to_string(&m).unwrap(); + let mut buffer = [0; 16384]; + let mut test: &mut [u8] = &mut buffer; + let mut size: usize = 0; + while let Ok(bytes) = test.write(j.as_bytes()) { + if bytes == 0 { + break; // No more space to write } - let callback = this.callback.lock().await; - let user_ptr = this.user_ptr.load(Ordering::Relaxed); - - (callback)(user_ptr, test.as_ptr(), size); + size += bytes; } + let callback = this.callback.lock().await; + let user_ptr = this.user_ptr.load(Ordering::Relaxed); + + (callback)(user_ptr, test.as_ptr(), size); + } else { + eprintln!("Failed to decode change"); } - }); + } Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::pubsub::change_listener::tests::setup_pubsub_emulator; + use crate::pubsub::protobuf::pbmessages::Network; + + use gcloud_googleapis::pubsub::v1::PubsubMessage; + use gcloud_pubsub::client::{Client, ClientConfig}; + use std::{ + collections::HashMap, + sync::atomic::{AtomicBool, Ordering}, + }; + + extern "C" fn dummy_callback(user_ptr: *mut c_void, data: *const u8, _size: usize) { + // Dummy callback for testing + assert!(!data.is_null(), "data pointer is null"); + assert!(!user_ptr.is_null(), "user_ptr pointer is null"); + let user_ptr = unsafe { &mut *(user_ptr as *mut TestNetworkListenr) }; + user_ptr.callback_called(); + println!("Dummy callback invoked"); + } + + struct TestNetworkListenr { + dummy_callback_called: bool, + } + + impl TestNetworkListenr { + fn new() -> Self { + Self { dummy_callback_called: false } + } + + fn callback_called(&mut self) { + self.dummy_callback_called = true; + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_network_listener() { + println!("Setting up Pub/Sub emulator for network listener test"); + let (_container, _host) = setup_pubsub_emulator().await.unwrap(); + + let mut tester = TestNetworkListenr::new(); + + let listener = NetworkListener::new( + "testctl", + Duration::from_secs(1), + dummy_callback, + &mut tester as *mut TestNetworkListenr as *mut c_void, + ) + .await + .unwrap(); + + let rt = tokio::runtime::Handle::current(); + + let run = Arc::new(AtomicBool::new(true)); + rt.spawn({ + let run = run.clone(); + let l = listener.clone(); + async move { + while run.load(Ordering::Relaxed) { + match l.listen().await { + Ok(_) => { + println!("Listener exited successfully"); + } + Err(e) => { + println!("Failed to start listener: {}", e); + assert!(false, "Listener failed to start"); + } + } + } + } + }); + + rt.spawn({ + let run = run.clone(); + let l = listener.clone(); + async move { + while run.load(Ordering::Relaxed) { + match l.change_handler().await { + Ok(_) => { + println!("Change handler started successfully"); + } + Err(e) => { + println!("Failed to start change handler: {}", e); + assert!(false, "Change handler failed to start"); + } + } + } + } + }); + + rt.spawn({ + async move { + let client = Client::new(ClientConfig::default()).await.unwrap(); + let topic = client.topic("controller-network-change-stream"); + if !topic.exists(None).await.unwrap() { + topic.create(None, None).await.unwrap(); + } + + let mut publisher = topic.new_publisher(None); + + let nc = NetworkChange { + old: Some(Network { + network_id: "test_network".to_string(), + name: Some("Test Network".to_string()), + ..Default::default() + }), + new: Some(Network { + network_id: "test_network".to_string(), + name: Some("Test Network Updated".to_string()), + ..Default::default() + }), + ..Default::default() + }; + + let data = NetworkChange::encode_to_vec(&nc); + let message = PubsubMessage { + data: data.into(), + attributes: HashMap::from([("controller_id".to_string(), "testctl".to_string())]), + ordering_key: format!("networks-{}", "testctl"), + ..Default::default() + }; + let awaiter = publisher.publish(message).await; + + match awaiter.get().await { + Ok(_) => println!("Message published successfully"), + Err(e) => { + assert!(false, "Failed to publish message: {}", e); + eprintln!("Failed to publish message: {}", e) + } + } + publisher.shutdown().await; + } + }); + + let mut counter = 0; + while !tester.dummy_callback_called && counter < 100 { + tokio::time::sleep(Duration::from_millis(100)).await; + counter += 1; + } + + run.store(false, Ordering::Relaxed); + assert!(tester.dummy_callback_called, "Callback was not called"); + } +}