From 0d5a528140930b50d0c5d4b652a5b0d12bbac10d Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 28 Mar 2026 16:02:19 +0100 Subject: [PATCH] Small enhancements. --- src/nats_cls.rs | 30 ++++++++++++++++-------------- src/utils/headers.rs | 6 +++--- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/nats_cls.rs b/src/nats_cls.rs index 07f1665..a96a431 100644 --- a/src/nats_cls.rs +++ b/src/nats_cls.rs @@ -28,13 +28,15 @@ pub struct NatsCls { request_timeout: Option, } -/// Helper to read the client from the `RwLock`. Returns a clone of the Client if present. -fn get_client(session: &RwLock>) -> NatsrpyResult { - session - .read() - .map_err(|_| NatsrpyError::SessionError("Lock poisoned".to_string()))? - .clone() - .ok_or(NatsrpyError::NotInitialized) +impl NatsCls { + // Small utility for getting nats session. + fn get_client(&self) -> NatsrpyResult { + self.nats_session + .read() + .map_err(|_| NatsrpyError::PoisonedLock)? + .clone() + .ok_or(NatsrpyError::NotInitialized) + } } #[pyo3::pymethods] @@ -137,7 +139,7 @@ impl NatsCls { reply: Option, err_on_disconnect: bool, ) -> NatsrpyResult> { - let client = get_client(&self.nats_session)?; + let client = self.get_client()?; let data = bytes::Bytes::from(payload); let headermap = headers .map(async_nats::HeaderMap::from_pydict) @@ -175,7 +177,7 @@ impl NatsCls { inbox: Option, timeout: Option, ) -> NatsrpyResult> { - let client = get_client(&self.nats_session)?; + let client = self.get_client()?; let data = payload.map(bytes::Bytes::from); let headermap = headers .map(async_nats::HeaderMap::from_pydict) @@ -198,7 +200,7 @@ impl NatsCls { pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Draining NATS session"); - let client = get_client(&self.nats_session)?; + let client = self.get_client()?; natsrpy_future(py, async move { client.drain().await?; Ok(()) @@ -214,7 +216,7 @@ impl NatsCls { queue: Option, ) -> NatsrpyResult> { log::debug!("Subscribing to '{subject}'"); - let client = get_client(&self.nats_session)?; + let client = self.get_client()?; natsrpy_future(py, async move { let subscriber = if let Some(queue) = queue { client.queue_subscribe(subject, queue).await? @@ -258,7 +260,7 @@ impl NatsCls { "Either domain or api_prefix should be specified, not both.", ))); } - let client = get_client(&self.nats_session)?; + let client = self.get_client()?; natsrpy_future(py, async move { let mut builder = async_nats::jetstream::ContextBuilder::new().concurrency_limit(concurrency_limit); @@ -288,7 +290,7 @@ impl NatsCls { pub fn shutdown<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Closing nats session"); let session = self.nats_session.clone(); - let client = get_client(&session)?; + let client = self.get_client()?; // Set session to None immediately so no new operations can start. { let mut guard = session @@ -304,7 +306,7 @@ impl NatsCls { pub fn flush<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Flushing streams"); - let client = get_client(&self.nats_session)?; + let client = self.get_client()?; natsrpy_future(py, async move { client.flush().await?; Ok(()) diff --git a/src/utils/headers.rs b/src/utils/headers.rs index 30b2d00..3ad9671 100644 --- a/src/utils/headers.rs +++ b/src/utils/headers.rs @@ -14,14 +14,14 @@ impl NatsrpyHeadermapExt for async_nats::HeaderMap { fn from_pydict(pydict: Bound) -> NatsrpyResult { let mut headermap = Self::new(); for (name, val) in pydict { - let rs_name = name.extract::()?; - if let Ok(parsed_str) = val.extract::() { + let rs_name = name.extract::<&str>()?; + if let Ok(parsed_str) = val.extract::<&str>() { headermap.insert(rs_name, parsed_str); continue; } if let Ok(parsed_list) = val.extract::>() { for inner in parsed_list { - headermap.append(rs_name.as_str(), inner); + headermap.append(rs_name, inner); } continue; }