diff --git a/crates/nu-plugin/src/lib.rs b/crates/nu-plugin/src/lib.rs index 799ffd96c0..6bd689d4d4 100644 --- a/crates/nu-plugin/src/lib.rs +++ b/crates/nu-plugin/src/lib.rs @@ -59,6 +59,7 @@ mod plugin; mod protocol; mod sequence; mod serializers; +mod util; pub use plugin::{ serve_plugin, EngineInterface, Plugin, PluginCommand, PluginEncoder, SimplePluginCommand, diff --git a/crates/nu-plugin/src/plugin/context.rs b/crates/nu-plugin/src/plugin/context.rs index 0874252263..18c50a6024 100644 --- a/crates/nu-plugin/src/plugin/context.rs +++ b/crates/nu-plugin/src/plugin/context.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, collections::HashMap, sync::{atomic::AtomicBool, Arc}, }; @@ -10,6 +11,8 @@ use nu_protocol::{ Config, IntoSpanned, IoStream, PipelineData, PluginIdentity, ShellError, Span, Spanned, Value, }; +use crate::util::MutableCow; + /// Object safe trait for abstracting operations required of the plugin context. pub(crate) trait PluginExecutionContext: Send + Sync { /// The [Span] for the command execution (`call.head`) @@ -26,8 +29,10 @@ pub(crate) trait PluginExecutionContext: Send + Sync { fn get_env_var(&self, name: &str) -> Result, ShellError>; /// Get all environment variables fn get_env_vars(&self) -> Result, ShellError>; - // Get current working directory + /// Get current working directory fn get_current_dir(&self) -> Result, ShellError>; + /// Set an environment variable + fn add_env_var(&mut self, name: String, value: Value) -> Result<(), ShellError>; /// Evaluate a closure passed to the plugin fn eval_closure( &self, @@ -37,33 +42,35 @@ pub(crate) trait PluginExecutionContext: Send + Sync { redirect_stdout: bool, redirect_stderr: bool, ) -> Result; + /// Create an owned version of the context with `'static` lifetime + fn boxed(&self) -> Box; } -/// The execution context of a plugin command. -pub(crate) struct PluginExecutionCommandContext { +/// The execution context of a plugin command. Can be borrowed. +pub(crate) struct PluginExecutionCommandContext<'a> { identity: Arc, - engine_state: EngineState, - stack: Stack, - call: Call, + engine_state: Cow<'a, EngineState>, + stack: MutableCow<'a, Stack>, + call: Cow<'a, Call>, } -impl PluginExecutionCommandContext { +impl<'a> PluginExecutionCommandContext<'a> { pub fn new( identity: Arc, - engine_state: &EngineState, - stack: &Stack, - call: &Call, - ) -> PluginExecutionCommandContext { + engine_state: &'a EngineState, + stack: &'a mut Stack, + call: &'a Call, + ) -> PluginExecutionCommandContext<'a> { PluginExecutionCommandContext { identity, - engine_state: engine_state.clone(), - stack: stack.clone(), - call: call.clone(), + engine_state: Cow::Borrowed(engine_state), + stack: MutableCow::Borrowed(stack), + call: Cow::Borrowed(call), } } } -impl PluginExecutionContext for PluginExecutionCommandContext { +impl<'a> PluginExecutionContext for PluginExecutionCommandContext<'a> { fn command_span(&self) -> Span { self.call.head } @@ -131,6 +138,11 @@ impl PluginExecutionContext for PluginExecutionCommandContext { Ok(cwd.into_spanned(self.call.head)) } + fn add_env_var(&mut self, name: String, value: Value) -> Result<(), ShellError> { + self.stack.add_env_var(name, value); + Ok(()) + } + fn eval_closure( &self, closure: Spanned, @@ -191,6 +203,15 @@ impl PluginExecutionContext for PluginExecutionCommandContext { eval_block_with_early_return(&self.engine_state, stack, block, input) } + + fn boxed(&self) -> Box { + Box::new(PluginExecutionCommandContext { + identity: self.identity.clone(), + engine_state: Cow::Owned(self.engine_state.clone().into_owned()), + stack: self.stack.owned(), + call: Cow::Owned(self.call.clone().into_owned()), + }) + } } /// A bogus execution context for testing that doesn't really implement anything properly @@ -239,6 +260,12 @@ impl PluginExecutionContext for PluginExecutionBogusContext { }) } + fn add_env_var(&mut self, _name: String, _value: Value) -> Result<(), ShellError> { + Err(ShellError::NushellFailed { + msg: "add_env_var not implemented on bogus".into(), + }) + } + fn eval_closure( &self, _closure: Spanned, @@ -251,4 +278,8 @@ impl PluginExecutionContext for PluginExecutionBogusContext { msg: "eval_closure not implemented on bogus".into(), }) } + + fn boxed(&self) -> Box { + Box::new(PluginExecutionBogusContext) + } } diff --git a/crates/nu-plugin/src/plugin/declaration.rs b/crates/nu-plugin/src/plugin/declaration.rs index b1f66545f6..3f57908f0d 100644 --- a/crates/nu-plugin/src/plugin/declaration.rs +++ b/crates/nu-plugin/src/plugin/declaration.rs @@ -108,12 +108,12 @@ impl Command for PluginDeclaration { })?; // Create the context to execute in - this supports engine calls and custom values - let context = Arc::new(PluginExecutionCommandContext::new( + let mut context = PluginExecutionCommandContext::new( self.source.identity.clone(), engine_state, stack, call, - )); + ); plugin.run( CallInfo { @@ -121,7 +121,7 @@ impl Command for PluginDeclaration { call: evaluated_call, input, }, - context, + &mut context, ) } diff --git a/crates/nu-plugin/src/plugin/interface/engine.rs b/crates/nu-plugin/src/plugin/interface/engine.rs index 37c1c6fba9..c4ebf7abfd 100644 --- a/crates/nu-plugin/src/plugin/interface/engine.rs +++ b/crates/nu-plugin/src/plugin/interface/engine.rs @@ -458,6 +458,9 @@ impl EngineInterface { EngineCall::GetEnvVar(name) => (EngineCall::GetEnvVar(name), Default::default()), EngineCall::GetEnvVars => (EngineCall::GetEnvVars, Default::default()), EngineCall::GetCurrentDir => (EngineCall::GetCurrentDir, Default::default()), + EngineCall::AddEnvVar(name, value) => { + (EngineCall::AddEnvVar(name, value), Default::default()) + } }; // Register the channel @@ -622,6 +625,30 @@ impl EngineInterface { } } + /// Set an environment variable in the caller's scope. + /// + /// If called after the plugin response has already been sent (i.e. during a stream), this will + /// only affect the environment for engine calls related to this plugin call, and will not be + /// propagated to the environment of the caller. + /// + /// # Example + /// ```rust,no_run + /// # use nu_protocol::{Value, ShellError}; + /// # use nu_plugin::EngineInterface; + /// # fn example(engine: &EngineInterface) -> Result<(), ShellError> { + /// engine.add_env_var("FOO", Value::test_string("bar")) + /// # } + /// ``` + pub fn add_env_var(&self, name: impl Into, value: Value) -> Result<(), ShellError> { + match self.engine_call(EngineCall::AddEnvVar(name.into(), value))? { + EngineCallResponse::PipelineData(_) => Ok(()), + EngineCallResponse::Error(err) => Err(err), + _ => Err(ShellError::PluginFailedToDecode { + msg: "Received unexpected response type for EngineCall::AddEnvVar".into(), + }), + } + } + /// Ask the engine to evaluate a closure. Input to the closure is passed as a stream, and the /// output is available as a stream. /// diff --git a/crates/nu-plugin/src/plugin/interface/engine/tests.rs b/crates/nu-plugin/src/plugin/interface/engine/tests.rs index 918055fb3c..372f3adf7a 100644 --- a/crates/nu-plugin/src/plugin/interface/engine/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/engine/tests.rs @@ -953,6 +953,20 @@ fn interface_get_env_vars() -> Result<(), ShellError> { Ok(()) } +#[test] +fn interface_add_env_var() -> Result<(), ShellError> { + let test = TestCase::new(); + let manager = test.engine(); + let interface = manager.interface_for_context(0); + + start_fake_plugin_call_responder(manager, 1, move |_| EngineCallResponse::empty()); + + interface.add_env_var("FOO", Value::test_string("bar"))?; + + assert!(test.has_unconsumed_write()); + Ok(()) +} + #[test] fn interface_eval_closure_with_stream() -> Result<(), ShellError> { let test = TestCase::new(); diff --git a/crates/nu-plugin/src/plugin/interface/plugin.rs b/crates/nu-plugin/src/plugin/interface/plugin.rs index 4aebb1a1d0..3e678c9440 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin.rs @@ -2,7 +2,7 @@ use std::{ collections::{btree_map, BTreeMap}, - sync::{mpsc, Arc, OnceLock}, + sync::{atomic::AtomicBool, mpsc, Arc, OnceLock}, }; use nu_protocol::{ @@ -44,8 +44,7 @@ enum ReceivedPluginCallMessage { } /// Context for plugin call execution -#[derive(Clone)] -pub(crate) struct Context(Arc); +pub(crate) struct Context(Box); impl std::fmt::Debug for Context { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -70,7 +69,7 @@ struct PluginInterfaceState { /// Sequence for generating stream ids stream_id_sequence: Sequence, /// Sender to subscribe to a plugin call response - plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallSubscription)>, + plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallState)>, /// An error that should be propagated to further plugin calls error: OnceLock, /// The synchronized output writer @@ -91,14 +90,15 @@ impl std::fmt::Debug for PluginInterfaceState { } } -/// Sent to the [`PluginInterfaceManager`] before making a plugin call to indicate interest in its -/// response. +/// State that the manager keeps for each plugin call during its lifetime. #[derive(Debug)] -struct PluginCallSubscription { +struct PluginCallState { /// The sender back to the thread that is waiting for the plugin call response sender: Option>, - /// Optional context for the environment of a plugin call for servicing engine calls - context: Option, + /// Interrupt signal to be used for stream iterators + ctrlc: Option>, + /// Channel to receive context on to be used if needed + context_rx: Option>, /// Number of streams that still need to be read from the plugin call response remaining_streams_to_read: i32, } @@ -112,10 +112,10 @@ pub(crate) struct PluginInterfaceManager { stream_manager: StreamManager, /// Protocol version info, set after `Hello` received protocol_info: Option, - /// Subscriptions for messages related to plugin calls - plugin_call_subscriptions: BTreeMap, + /// State related to plugin calls + plugin_call_states: BTreeMap, /// Receiver for plugin call subscriptions - plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallSubscription)>, + plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallState)>, /// Tracker for which plugin call streams being read belong to /// /// This is necessary so we know when we can remove context for plugin calls @@ -142,7 +142,7 @@ impl PluginInterfaceManager { }), stream_manager: StreamManager::new(), protocol_info: None, - plugin_call_subscriptions: BTreeMap::new(), + plugin_call_states: BTreeMap::new(), plugin_call_subscription_receiver: subscription_rx, plugin_call_input_streams: BTreeMap::new(), gc: None, @@ -158,9 +158,9 @@ impl PluginInterfaceManager { /// Consume pending messages in the `plugin_call_subscription_receiver` fn receive_plugin_call_subscriptions(&mut self) { - while let Ok((id, subscription)) = self.plugin_call_subscription_receiver.try_recv() { - if let btree_map::Entry::Vacant(e) = self.plugin_call_subscriptions.entry(id) { - e.insert(subscription); + while let Ok((id, state)) = self.plugin_call_subscription_receiver.try_recv() { + if let btree_map::Entry::Vacant(e) = self.plugin_call_states.entry(id) { + e.insert(state); } else { log::warn!("Duplicate plugin call ID ignored: {id}"); } @@ -172,8 +172,8 @@ impl PluginInterfaceManager { self.plugin_call_input_streams.insert(stream_id, call_id); // Increment the number of streams on the subscription so context stays alive self.receive_plugin_call_subscriptions(); - if let Some(sub) = self.plugin_call_subscriptions.get_mut(&call_id) { - sub.remaining_streams_to_read += 1; + if let Some(state) = self.plugin_call_states.get_mut(&call_id) { + state.remaining_streams_to_read += 1; } // Add a lock to the garbage collector for each stream if let Some(ref gc) = self.gc { @@ -184,8 +184,7 @@ impl PluginInterfaceManager { /// Track the end of an incoming stream fn recv_stream_ended(&mut self, stream_id: StreamId) { if let Some(call_id) = self.plugin_call_input_streams.remove(&stream_id) { - if let btree_map::Entry::Occupied(mut e) = self.plugin_call_subscriptions.entry(call_id) - { + if let btree_map::Entry::Occupied(mut e) = self.plugin_call_states.entry(call_id) { e.get_mut().remaining_streams_to_read -= 1; // Remove the subscription if there are no more streams to be read. if e.get().remaining_streams_to_read <= 0 { @@ -200,14 +199,14 @@ impl PluginInterfaceManager { } } - /// Find the context corresponding to the given plugin call id - fn get_context(&mut self, id: PluginCallId) -> Result, ShellError> { + /// Find the ctrlc signal corresponding to the given plugin call id + fn get_ctrlc(&mut self, id: PluginCallId) -> Result>, ShellError> { // Make sure we're up to date self.receive_plugin_call_subscriptions(); // Find the subscription and return the context - self.plugin_call_subscriptions + self.plugin_call_states .get(&id) - .map(|sub| sub.context.clone()) + .map(|state| state.ctrlc.clone()) .ok_or_else(|| ShellError::PluginFailedToDecode { msg: format!("Unknown plugin call ID: {id}"), }) @@ -222,7 +221,7 @@ impl PluginInterfaceManager { // Ensure we're caught up on the subscriptions made self.receive_plugin_call_subscriptions(); - if let btree_map::Entry::Occupied(mut e) = self.plugin_call_subscriptions.entry(id) { + if let btree_map::Entry::Occupied(mut e) = self.plugin_call_states.entry(id) { // Remove the subscription sender, since this will be the last message. // // We can spawn a new one if we need it for engine calls. @@ -254,11 +253,23 @@ impl PluginInterfaceManager { ) -> Result<&mpsc::Sender, ShellError> { let interface = self.get_interface(); - if let Some(sub) = self.plugin_call_subscriptions.get_mut(&id) { - if sub.sender.is_none() { + if let Some(state) = self.plugin_call_states.get_mut(&id) { + if state.sender.is_none() { let (tx, rx) = mpsc::channel(); - let context = sub.context.clone(); + let context_rx = + state + .context_rx + .take() + .ok_or_else(|| ShellError::NushellFailed { + msg: "Tried to spawn the fallback engine call handler more than once" + .into(), + })?; let handler = move || { + // We receive on the thread so that we don't block the reader thread + let mut context = context_rx + .recv() + .ok() // The plugin call won't send context if it's not required. + .map(|c| c.0); for msg in rx { // This thread only handles engine calls. match msg { @@ -266,7 +277,7 @@ impl PluginInterfaceManager { if let Err(err) = interface.handle_engine_call( engine_call_id, engine_call, - &context, + context.as_deref_mut(), ) { log::warn!( "Error in plugin post-response engine call handler: \ @@ -286,8 +297,8 @@ impl PluginInterfaceManager { .name("plugin engine call handler".into()) .spawn(handler) .expect("failed to spawn thread"); - sub.sender = Some(tx); - Ok(sub.sender.as_ref().unwrap_or_else(|| unreachable!())) + state.sender = Some(tx); + Ok(state.sender.as_ref().unwrap_or_else(|| unreachable!())) } else { Err(ShellError::NushellFailed { msg: "Tried to spawn the fallback engine call handler before the plugin call \ @@ -313,7 +324,7 @@ impl PluginInterfaceManager { self.receive_plugin_call_subscriptions(); // Don't remove the sender, as there could be more calls or responses - if let Some(subscription) = self.plugin_call_subscriptions.get(&plugin_call_id) { + if let Some(subscription) = self.plugin_call_states.get(&plugin_call_id) { let msg = ReceivedPluginCallMessage::EngineCall(engine_call_id, call); // Call if there's an error sending the engine call let send_error = |this: &Self| { @@ -374,9 +385,7 @@ impl PluginInterfaceManager { let _ = self.stream_manager.broadcast_read_error(err.clone()); // Error to call waiters self.receive_plugin_call_subscriptions(); - for subscription in - std::mem::take(&mut self.plugin_call_subscriptions).into_values() - { + for subscription in std::mem::take(&mut self.plugin_call_states).into_values() { let _ = subscription .sender .as_ref() @@ -460,15 +469,14 @@ impl InterfaceManager for PluginInterfaceManager { PluginCallResponse::PipelineData(data) => { // If there's an error with initializing this stream, change it to a plugin // error response, but send it anyway - let exec_context = self.get_context(id)?; - let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc()); + let ctrlc = self.get_ctrlc(id)?; // Register the streams in the response for stream_id in data.stream_ids() { self.recv_stream_started(id, stream_id); } - match self.read_pipeline_data(data, ctrlc) { + match self.read_pipeline_data(data, ctrlc.as_ref()) { Ok(data) => PluginCallResponse::PipelineData(data), Err(err) => PluginCallResponse::Error(err.into()), } @@ -485,14 +493,14 @@ impl InterfaceManager for PluginInterfaceManager { } PluginOutput::EngineCall { context, id, call } => { // Handle reading the pipeline data, if any - let exec_context = self.get_context(context)?; - let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc()); + let ctrlc = self.get_ctrlc(context)?; let call = match call { EngineCall::GetConfig => Ok(EngineCall::GetConfig), EngineCall::GetPluginConfig => Ok(EngineCall::GetPluginConfig), EngineCall::GetEnvVar(name) => Ok(EngineCall::GetEnvVar(name)), EngineCall::GetEnvVars => Ok(EngineCall::GetEnvVars), EngineCall::GetCurrentDir => Ok(EngineCall::GetCurrentDir), + EngineCall::AddEnvVar(name, value) => Ok(EngineCall::AddEnvVar(name, value)), EngineCall::EvalClosure { closure, mut positional, @@ -504,14 +512,15 @@ impl InterfaceManager for PluginInterfaceManager { for arg in positional.iter_mut() { PluginCustomValue::add_source(arg, &self.state.source); } - self.read_pipeline_data(input, ctrlc) - .map(|input| EngineCall::EvalClosure { + self.read_pipeline_data(input, ctrlc.as_ref()).map(|input| { + EngineCall::EvalClosure { closure, positional, input, redirect_stdout, redirect_stderr, - }) + } + }) } }; match call { @@ -622,7 +631,8 @@ impl PluginInterface { fn write_plugin_call( &self, call: PluginCall, - context: Option, + ctrlc: Option>, + context_rx: mpsc::Receiver, ) -> Result< ( PipelineDataWriter, @@ -662,9 +672,10 @@ impl PluginInterface { .plugin_call_subscription_sender .send(( id, - PluginCallSubscription { + PluginCallState { sender: Some(tx), - context, + ctrlc, + context_rx: Some(context_rx), remaining_streams_to_read: 0, }, )) @@ -703,19 +714,26 @@ impl PluginInterface { fn receive_plugin_call_response( &self, rx: mpsc::Receiver, - context: &Option, + mut context: Option<&mut (dyn PluginExecutionContext + '_)>, + context_tx: mpsc::Sender, ) -> Result, ShellError> { // Handle message from receiver for msg in rx { match msg { ReceivedPluginCallMessage::Response(resp) => { + if resp.has_stream() { + // If the response has a stream, we need to register the context + if let Some(context) = context { + let _ = context_tx.send(Context(context.boxed())); + } + } return Ok(resp); } ReceivedPluginCallMessage::Error(err) => { return Err(err); } ReceivedPluginCallMessage::EngineCall(engine_call_id, engine_call) => { - self.handle_engine_call(engine_call_id, engine_call, context)?; + self.handle_engine_call(engine_call_id, engine_call, context.as_deref_mut())?; } } } @@ -730,7 +748,7 @@ impl PluginInterface { &self, engine_call_id: EngineCallId, engine_call: EngineCall, - context: &Option, + context: Option<&mut (dyn PluginExecutionContext + '_)>, ) -> Result<(), ShellError> { let resp = handle_engine_call(engine_call, context).unwrap_or_else(EngineCallResponse::Error); @@ -763,7 +781,7 @@ impl PluginInterface { fn plugin_call( &self, call: PluginCall, - context: &Option, + context: Option<&mut dyn PluginExecutionContext>, ) -> Result, ShellError> { // Check for an error in the state first, and return it if set. if let Some(error) = self.state.error.get() { @@ -777,17 +795,24 @@ impl PluginInterface { gc.increment_locks(1); } - let (writer, rx) = self.write_plugin_call(call, context.clone())?; + // Create the channel to send context on if needed + let (context_tx, context_rx) = mpsc::channel(); + + let (writer, rx) = self.write_plugin_call( + call, + context.as_ref().and_then(|c| c.ctrlc().cloned()), + context_rx, + )?; // Finish writing stream in the background writer.write_background()?; - self.receive_plugin_call_response(rx, context) + self.receive_plugin_call_response(rx, context, context_tx) } /// Get the command signatures from the plugin. pub(crate) fn get_signature(&self) -> Result, ShellError> { - match self.plugin_call(PluginCall::Signature, &None)? { + match self.plugin_call(PluginCall::Signature, None)? { PluginCallResponse::Signature(sigs) => Ok(sigs), PluginCallResponse::Error(err) => Err(err.into()), _ => Err(ShellError::PluginFailedToDecode { @@ -800,10 +825,9 @@ impl PluginInterface { pub(crate) fn run( &self, call: CallInfo, - context: Arc, + context: &mut dyn PluginExecutionContext, ) -> Result { - let context = Some(Context(context)); - match self.plugin_call(PluginCall::Run(call), &context)? { + match self.plugin_call(PluginCall::Run(call), Some(context))? { PluginCallResponse::PipelineData(data) => Ok(data), PluginCallResponse::Error(err) => Err(err.into()), _ => Err(ShellError::PluginFailedToDecode { @@ -821,7 +845,7 @@ impl PluginInterface { let op_name = op.name(); let span = value.span; let call = PluginCall::CustomValueOp(value, op); - match self.plugin_call(call, &None)? { + match self.plugin_call(call, None)? { PluginCallResponse::PipelineData(out_data) => Ok(out_data.into_value(span)), PluginCallResponse::Error(err) => Err(err.into()), _ => Err(ShellError::PluginFailedToDecode { @@ -869,7 +893,7 @@ impl PluginInterface { value.into_spanned(Span::unknown()), CustomValueOp::PartialCmp(other_value), ); - match self.plugin_call(call, &None)? { + match self.plugin_call(call, None)? { PluginCallResponse::Ordering(ordering) => Ok(ordering), PluginCallResponse::Error(err) => Err(err.into()), _ => Err(ShellError::PluginFailedToDecode { @@ -977,56 +1001,53 @@ impl Drop for PluginInterface { /// Handle an engine call. pub(crate) fn handle_engine_call( call: EngineCall, - context: &Option, + context: Option<&mut (dyn PluginExecutionContext + '_)>, ) -> Result, ShellError> { let call_name = call.name(); - let require_context = || { - context.as_ref().ok_or_else(|| ShellError::GenericError { - error: "A plugin execution context is required for this engine call".into(), - msg: format!( - "attempted to call {} outside of a command invocation", - call_name - ), - span: None, - help: Some("this is probably a bug with the plugin".into()), - inner: vec![], - }) - }; + + let context = context.ok_or_else(|| ShellError::GenericError { + error: "A plugin execution context is required for this engine call".into(), + msg: format!( + "attempted to call {} outside of a command invocation", + call_name + ), + span: None, + help: Some("this is probably a bug with the plugin".into()), + inner: vec![], + })?; + match call { EngineCall::GetConfig => { - let context = require_context()?; let config = Box::new(context.get_config()?); Ok(EngineCallResponse::Config(config)) } EngineCall::GetPluginConfig => { - let context = require_context()?; let plugin_config = context.get_plugin_config()?; Ok(plugin_config.map_or_else(EngineCallResponse::empty, EngineCallResponse::value)) } EngineCall::GetEnvVar(name) => { - let context = require_context()?; let value = context.get_env_var(&name)?; Ok(value.map_or_else(EngineCallResponse::empty, EngineCallResponse::value)) } - EngineCall::GetEnvVars => { - let context = require_context()?; - context.get_env_vars().map(EngineCallResponse::ValueMap) - } + EngineCall::GetEnvVars => context.get_env_vars().map(EngineCallResponse::ValueMap), EngineCall::GetCurrentDir => { - let context = require_context()?; let current_dir = context.get_current_dir()?; Ok(EngineCallResponse::value(Value::string( current_dir.item, current_dir.span, ))) } + EngineCall::AddEnvVar(name, value) => { + context.add_env_var(name, value)?; + Ok(EngineCallResponse::empty()) + } EngineCall::EvalClosure { closure, positional, input, redirect_stdout, redirect_stderr, - } => require_context()? + } => context .eval_closure(closure, positional, input, redirect_stdout, redirect_stderr) .map(EngineCallResponse::PipelineData), } diff --git a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs index dd3c4fe5fd..83eb2ccea8 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs @@ -1,7 +1,4 @@ -use std::{ - sync::{mpsc, Arc}, - time::Duration, -}; +use std::{sync::mpsc, time::Duration}; use nu_protocol::{ engine::Closure, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, @@ -24,8 +21,7 @@ use crate::{ }; use super::{ - Context, PluginCallSubscription, PluginInterface, PluginInterfaceManager, - ReceivedPluginCallMessage, + Context, PluginCallState, PluginInterface, PluginInterfaceManager, ReceivedPluginCallMessage, }; #[test] @@ -187,11 +183,12 @@ fn fake_plugin_call( // Set up a fake plugin call subscription let (tx, rx) = mpsc::channel(); - manager.plugin_call_subscriptions.insert( + manager.plugin_call_states.insert( id, - PluginCallSubscription { + PluginCallState { sender: Some(tx), - context: None, + ctrlc: None, + context_rx: None, remaining_streams_to_read: 0, }, ); @@ -388,7 +385,7 @@ fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> { ))?; // ListStream should have one - if let Some(sub) = manager.plugin_call_subscriptions.get(&0) { + if let Some(sub) = manager.plugin_call_states.get(&0) { assert_eq!( 1, sub.remaining_streams_to_read, "ListStream remaining_streams_to_read should be 1" @@ -403,7 +400,7 @@ fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> { ); // ExternalStream should have three - if let Some(sub) = manager.plugin_call_subscriptions.get(&1) { + if let Some(sub) = manager.plugin_call_states.get(&1) { assert_eq!( 3, sub.remaining_streams_to_read, "ExternalStream remaining_streams_to_read should be 3" @@ -483,20 +480,25 @@ fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError let mut manager = test.plugin("test"); manager.protocol_info = Some(ProtocolInfo::default()); - let bogus = Context(Arc::new(PluginExecutionBogusContext)); + let (context_tx, context_rx) = mpsc::channel(); // Set up a situation identical to what we would find if the response had been read, but there // was still a stream being processed. We have nowhere to send the engine call in that case, // so the manager has to create a place to handle it. - manager.plugin_call_subscriptions.insert( + manager.plugin_call_states.insert( 0, - PluginCallSubscription { + PluginCallState { sender: None, - context: Some(bogus), + ctrlc: None, + context_rx: Some(context_rx), remaining_streams_to_read: 1, }, ); + // The engine will get the context from the channel + let bogus = Context(Box::new(PluginExecutionBogusContext)); + context_tx.send(bogus).expect("failed to send"); + manager.send_engine_call(0, 0, EngineCall::GetConfig)?; // Not really much choice but to wait here, as the thread will have been spawned in the @@ -528,7 +530,7 @@ fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError // Whatever was used to make this happen should have been held onto, since spawning a thread // is expensive let sender = &manager - .plugin_call_subscriptions + .plugin_call_states .get(&0) .expect("missing subscription 0") .sender; @@ -546,11 +548,12 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read( let mut manager = TestCase::new().plugin("test"); for n in [0, 1] { - manager.plugin_call_subscriptions.insert( + manager.plugin_call_states.insert( n, - PluginCallSubscription { + PluginCallState { sender: None, - context: None, + ctrlc: None, + context_rx: None, remaining_streams_to_read: n as i32, }, ); @@ -562,11 +565,11 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read( // 0 should not still be present, but 1 should be assert!( - !manager.plugin_call_subscriptions.contains_key(&0), + !manager.plugin_call_states.contains_key(&0), "didn't clean up when there weren't remaining streams" ); assert!( - manager.plugin_call_subscriptions.contains_key(&1), + manager.plugin_call_states.contains_key(&1), "clean up even though there were remaining streams" ); Ok(()) @@ -578,11 +581,12 @@ fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<() manager.protocol_info = Some(ProtocolInfo::default()); for n in [1, 2] { - manager.plugin_call_subscriptions.insert( + manager.plugin_call_states.insert( n, - PluginCallSubscription { + PluginCallState { sender: None, - context: None, + ctrlc: None, + context_rx: None, remaining_streams_to_read: n as i32, }, ); @@ -608,21 +612,21 @@ fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<() // Ending 10 should cause 1 to be removed manager.consume(StreamMessage::End(10).into())?; assert!( - !manager.plugin_call_subscriptions.contains_key(&1), + !manager.plugin_call_states.contains_key(&1), "contains(1) after End(10)" ); // Ending 21 should not cause 2 to be removed manager.consume(StreamMessage::End(21).into())?; assert!( - manager.plugin_call_subscriptions.contains_key(&2), + manager.plugin_call_states.contains_key(&2), "!contains(2) after End(21)" ); // Ending 22 should cause 2 to be removed manager.consume(StreamMessage::End(22).into())?; assert!( - !manager.plugin_call_subscriptions.contains_key(&2), + !manager.plugin_call_states.contains_key(&2), "contains(2) after End(22)" ); @@ -728,18 +732,15 @@ fn interface_goodbye() -> Result<(), ShellError> { fn interface_write_plugin_call_registers_subscription() -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); assert!( - manager.plugin_call_subscriptions.is_empty(), + manager.plugin_call_states.is_empty(), "plugin call subscriptions not empty before start of test" ); let interface = manager.get_interface(); - let _ = interface.write_plugin_call(PluginCall::Signature, None)?; + let _ = interface.write_plugin_call(PluginCall::Signature, None, mpsc::channel().1)?; manager.receive_plugin_call_subscriptions(); - assert!( - !manager.plugin_call_subscriptions.is_empty(), - "not registered" - ); + assert!(!manager.plugin_call_states.is_empty(), "not registered"); Ok(()) } @@ -749,7 +750,8 @@ fn interface_write_plugin_call_writes_signature() -> Result<(), ShellError> { let manager = test.plugin("test"); let interface = manager.get_interface(); - let (writer, _) = interface.write_plugin_call(PluginCall::Signature, None)?; + let (writer, _) = + interface.write_plugin_call(PluginCall::Signature, None, mpsc::channel().1)?; writer.write()?; let written = test.next_written().expect("nothing written"); @@ -778,6 +780,7 @@ fn interface_write_plugin_call_writes_custom_value_op() -> Result<(), ShellError CustomValueOp::ToBaseValue, ), None, + mpsc::channel().1, )?; writer.write()?; @@ -812,6 +815,7 @@ fn interface_write_plugin_call_writes_run_with_value_input() -> Result<(), Shell input: PipelineData::Value(Value::test_int(-1), None), }), None, + mpsc::channel().1, )?; writer.write()?; @@ -850,6 +854,7 @@ fn interface_write_plugin_call_writes_run_with_stream_input() -> Result<(), Shel input: values.clone().into_pipeline_data(None), }), None, + mpsc::channel().1, )?; writer.write()?; @@ -912,7 +917,7 @@ fn interface_receive_plugin_call_receives_response() -> Result<(), ShellError> { .expect("failed to send on new channel"); drop(tx); // so we don't deadlock on recv() - let response = interface.receive_plugin_call_response(rx, &None)?; + let response = interface.receive_plugin_call_response(rx, None, mpsc::channel().0)?; assert!( matches!(response, PluginCallResponse::Signature(_)), "wrong response: {response:?}" @@ -935,7 +940,7 @@ fn interface_receive_plugin_call_receives_error() -> Result<(), ShellError> { drop(tx); // so we don't deadlock on recv() let error = interface - .receive_plugin_call_response(rx, &None) + .receive_plugin_call_response(rx, None, mpsc::channel().0) .expect_err("did not receive error"); assert!( matches!(error, ShellError::ExternalNotSupported { .. }), @@ -958,13 +963,13 @@ fn interface_receive_plugin_call_handles_engine_call() -> Result<(), ShellError> .expect("failed to send on new channel"); // The context should be a bogus context, which will return an error for GetConfig - let context = Some(Context(Arc::new(PluginExecutionBogusContext))); + let mut context = PluginExecutionBogusContext; // We don't actually send a response, so `receive_plugin_call_response` should actually return // an error, but it should still do the engine call drop(tx); interface - .receive_plugin_call_response(rx, &context) + .receive_plugin_call_response(rx, Some(&mut context), mpsc::channel().0) .expect_err("no error even though there was no response"); // Check for the engine call response output @@ -996,15 +1001,16 @@ fn start_fake_plugin_call_responder( std::thread::Builder::new() .name("fake plugin call responder".into()) .spawn(move || { - for (id, sub) in manager + for (id, state) in manager .plugin_call_subscription_receiver .into_iter() .take(take) { for message in f(id) { - sub.sender + state + .sender .as_ref() - .expect("sender is None") + .expect("sender was not set") .send(message) .expect("failed to send"); } @@ -1055,7 +1061,7 @@ fn interface_run() -> Result<(), ShellError> { }, input: PipelineData::Empty, }, - PluginExecutionBogusContext.into(), + &mut PluginExecutionBogusContext, )?; assert_eq!( diff --git a/crates/nu-plugin/src/protocol/mod.rs b/crates/nu-plugin/src/protocol/mod.rs index a241af5bec..d6ad73e857 100644 --- a/crates/nu-plugin/src/protocol/mod.rs +++ b/crates/nu-plugin/src/protocol/mod.rs @@ -348,6 +348,21 @@ impl PluginCallResponse { } } +impl PluginCallResponse { + /// Does this response have a stream? + pub(crate) fn has_stream(&self) -> bool { + match self { + PluginCallResponse::PipelineData(data) => match data { + PipelineData::Empty => false, + PipelineData::Value(..) => false, + PipelineData::ListStream(..) => true, + PipelineData::ExternalStream { .. } => true, + }, + _ => false, + } + } +} + /// Options that can be changed to affect how the engine treats the plugin #[derive(Serialize, Deserialize, Debug, Clone)] pub enum PluginOption { @@ -447,6 +462,8 @@ pub enum EngineCall { GetEnvVars, /// Get current working directory GetCurrentDir, + /// Set an environment variable in the caller's scope + AddEnvVar(String, Value), /// Evaluate a closure with stream input/output EvalClosure { /// The closure to call. @@ -473,6 +490,7 @@ impl EngineCall { EngineCall::GetEnvVar(_) => "GetEnv", EngineCall::GetEnvVars => "GetEnvs", EngineCall::GetCurrentDir => "GetCurrentDir", + EngineCall::AddEnvVar(..) => "AddEnvVar", EngineCall::EvalClosure { .. } => "EvalClosure", } } diff --git a/crates/nu-plugin/src/util/mod.rs b/crates/nu-plugin/src/util/mod.rs new file mode 100644 index 0000000000..92818a4054 --- /dev/null +++ b/crates/nu-plugin/src/util/mod.rs @@ -0,0 +1,3 @@ +mod mutable_cow; + +pub(crate) use mutable_cow::*; diff --git a/crates/nu-plugin/src/util/mutable_cow.rs b/crates/nu-plugin/src/util/mutable_cow.rs new file mode 100644 index 0000000000..e0f7807fe2 --- /dev/null +++ b/crates/nu-plugin/src/util/mutable_cow.rs @@ -0,0 +1,35 @@ +/// Like [`Cow`] but with a mutable reference instead. So not exactly clone-on-write, but can be +/// made owned. +pub enum MutableCow<'a, T> { + Borrowed(&'a mut T), + Owned(T), +} + +impl<'a, T: Clone> MutableCow<'a, T> { + pub fn owned(&self) -> MutableCow<'static, T> { + match self { + MutableCow::Borrowed(r) => MutableCow::Owned((*r).clone()), + MutableCow::Owned(o) => MutableCow::Owned(o.clone()), + } + } +} + +impl<'a, T> std::ops::Deref for MutableCow<'a, T> { + type Target = T; + + fn deref(&self) -> &T { + match self { + MutableCow::Borrowed(r) => r, + MutableCow::Owned(o) => o, + } + } +} + +impl<'a, T> std::ops::DerefMut for MutableCow<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + MutableCow::Borrowed(r) => r, + MutableCow::Owned(o) => o, + } + } +} diff --git a/crates/nu_plugin_example/src/commands/nu_example_env.rs b/crates/nu_plugin_example/src/commands/nu_example_env.rs index cca15c1b71..1d239779cc 100644 --- a/crates/nu_plugin_example/src/commands/nu_example_env.rs +++ b/crates/nu_plugin_example/src/commands/nu_example_env.rs @@ -19,6 +19,12 @@ impl SimplePluginCommand for NuExampleEnv { "The name of the environment variable to get", ) .switch("cwd", "Get current working directory instead", None) + .named( + "set", + SyntaxShape::Any, + "Set an environment variable to the value", + None, + ) .search_terms(vec!["example".into(), "env".into()]) .input_output_type(Type::Nothing, Type::Any) } @@ -31,8 +37,22 @@ impl SimplePluginCommand for NuExampleEnv { _input: &Value, ) -> Result { if call.has_flag("cwd")? { - // Get working directory - Ok(Value::string(engine.get_current_dir()?, call.head)) + match call.get_flag_value("set") { + None => { + // Get working directory + Ok(Value::string(engine.get_current_dir()?, call.head)) + } + Some(value) => Err(LabeledError { + label: "Invalid arguments".into(), + msg: "--cwd can't be used with --set".into(), + span: Some(value.span()), + }), + } + } else if let Some(value) = call.get_flag_value("set") { + // Set single env var + let name = call.req::(0)?; + engine.add_env_var(name, value)?; + Ok(Value::nothing(call.head)) } else if let Some(name) = call.opt::(0)? { // Get single env var Ok(engine diff --git a/tests/plugins/env.rs b/tests/plugins/env.rs index 7c33a8e848..83774f15e7 100644 --- a/tests/plugins/env.rs +++ b/tests/plugins/env.rs @@ -42,3 +42,14 @@ fn get_current_dir() { assert!(result.status.success()); assert_eq!(cwd, result.out); } + +#[test] +fn set_env() { + let result = nu_with_plugins!( + cwd: ".", + plugin: ("nu_plugin_example"), + "nu-example-env NUSHELL_OPINION --set=rocks; $env.NUSHELL_OPINION" + ); + assert!(result.status.success()); + assert_eq!("rocks", result.out); +}