diff --git a/controller/CV1.cpp b/controller/CV1.cpp index 0ba82f064..0f980c68f 100644 --- a/controller/CV1.cpp +++ b/controller/CV1.cpp @@ -67,6 +67,8 @@ CV1::CV1(const Identity& myId, const char* path, int listenPort, RedisConfig* rc auto span = tracer->StartSpan("cv1::CV1"); auto scope = tracer->WithActiveSpan(span); + rustybits::init_async_runtime(); + char myAddress[64]; _myAddressStr = myId.address().toString(myAddress); _connString = std::string(path); @@ -157,6 +159,8 @@ CV1::~CV1() _smee = NULL; } + rustybits::shutdown_async_runtime(); + _run = 0; std::this_thread::sleep_for(std::chrono::milliseconds(100)); diff --git a/controller/CV2.cpp b/controller/CV2.cpp index db9effb3c..9be950f1d 100644 --- a/controller/CV2.cpp +++ b/controller/CV2.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include using json = nlohmann::json; @@ -43,6 +44,8 @@ CV2::CV2(const Identity& myId, const char* path, int listenPort) : DB(), _pool() auto span = tracer->StartSpan("cv2::CV2"); auto scope = tracer->WithActiveSpan(span); + rustybits::init_async_runtime(); + fprintf(stderr, "CV2::CV2\n"); char myAddress[64]; _myAddressStr = myId.address().toString(myAddress); @@ -83,6 +86,8 @@ CV2::CV2(const Identity& myId, const char* path, int listenPort) : DB(), _pool() CV2::~CV2() { + rustybits::shutdown_async_runtime(); + _run = 0; std::this_thread::sleep_for(std::chrono::milliseconds(100)); diff --git a/rustybits/src/ext.rs b/rustybits/src/ext.rs index 67fe2c6c3..1b87b9dc1 100644 --- a/rustybits/src/ext.rs +++ b/rustybits/src/ext.rs @@ -14,6 +14,38 @@ use std::ffi::{CStr, CString}; use std::os::raw::c_char; use url::Url; +static mut RT: Option = None; + +static START: std::sync::Once = std::sync::Once::new(); +static SHUTDOWN: std::sync::Once = std::sync::Once::new(); + +#[no_mangle] +pub unsafe extern "C" fn init_async_runtime() { + START.call_once(|| { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .thread_name("rust-async-worker") + .enable_all() + .build() + .expect("Failed to create tokio runtime"); + + unsafe { RT = Some(rt) }; + }); +} + +#[no_mangle] +#[allow(static_mut_refs)] +pub unsafe extern "C" fn shutdown_async_runtime() { + SHUTDOWN.call_once(|| { + // Shutdown the tokio runtime + unsafe { + if let Some(rt) = RT.take() { + rt.shutdown_timeout(std::time::Duration::from_secs(5)); + } + } + }); +} + #[cfg(feature = "zeroidc")] use crate::zeroidc::ZeroIDC; @@ -419,8 +451,7 @@ pub unsafe extern "C" fn smee_client_delete(ptr: *mut SmeeClient) { assert!(!ptr.is_null()); Box::from_raw(&mut *ptr) }; - - smee.shutdown(); + drop(smee); } #[cfg(feature = "ztcontroller")] diff --git a/rustybits/src/pubsub/mod.rs b/rustybits/src/pubsub/mod.rs index 2fe3be6ba..719f52e59 100644 --- a/rustybits/src/pubsub/mod.rs +++ b/rustybits/src/pubsub/mod.rs @@ -9,3 +9,21 @@ * On the date above, in accordance with the Business Source License, use * of this software will be governed by version 2.0 of the Apache License. */ + +use gcloud_pubsub::client::{Client, ClientConfig}; + +pub struct PubSubClient { + client: Client, +} + +impl PubSubClient { + pub async fn new() -> Result> { + let config = ClientConfig::default().with_auth().await.unwrap(); + let client = Client::new(config).await?; + + // Assuming a topic name is required for the client + let topic_name = "default-topic".to_string(); + + Ok(Self { client }) + } +} diff --git a/rustybits/src/smeeclient/mod.rs b/rustybits/src/smeeclient/mod.rs index da743e4e5..b68ac9a6b 100644 --- a/rustybits/src/smeeclient/mod.rs +++ b/rustybits/src/smeeclient/mod.rs @@ -17,6 +17,7 @@ use temporal_sdk_core_protos::{ coresdk::AsJsonPayloadExt, temporal::api::enums::v1::{WorkflowIdConflictPolicy, WorkflowIdReusePolicy}, }; +use tokio::runtime::{Handle, Runtime}; use url::Url; use uuid::Uuid; @@ -43,16 +44,13 @@ impl NetworkJoinedParams { } pub struct SmeeClient { - tokio_rt: tokio::runtime::Runtime, client: RetryClient, task_queue: String, } impl SmeeClient { pub fn new(temporal_url: &str, namespace: &str, task_queue: &str) -> Result> { - // start tokio runtime. Required by temporal - let rt = tokio::runtime::Runtime::new()?; - + let rt = Handle::current(); let c = ClientOptionsBuilder::default() .target_url(Url::from_str(temporal_url).unwrap()) .client_name(CLIENT_NAME) @@ -61,11 +59,7 @@ impl SmeeClient { let con = rt.block_on(async { c.connect(namespace.to_string(), None).await })?; - Ok(Self { - tokio_rt: rt, - client: con, - task_queue: task_queue.to_string(), - }) + Ok(Self { client: con, task_queue: task_queue.to_string() }) } pub fn notify_network_joined(&self, params: NetworkJoinedParams) -> Result<(), Box> { @@ -89,7 +83,8 @@ impl SmeeClient { let workflow_id = Uuid::new_v4(); - self.tokio_rt.block_on(async { + let rt = Handle::current(); + rt.block_on(async { println!("calilng start_workflow"); self.client .start_workflow( @@ -105,8 +100,4 @@ impl SmeeClient { Ok(()) } - - pub fn shutdown(self) { - self.tokio_rt.shutdown_timeout(Duration::from_secs(5)) - } }