hydro_deploy/rust_crate/
ports.rs

1use std::any::Any;
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::ops::Deref;
5use std::sync::{Arc, Weak};
6
7use anyhow::{Result, bail};
8use async_recursion::async_recursion;
9use dyn_clone::DynClone;
10use hydro_deploy_integration::ServerPort;
11use tokio::sync::RwLock;
12
13use super::RustCrateService;
14use crate::{ClientStrategy, Host, LaunchedHost, PortNetworkHint, ServerStrategy};
15
16pub trait RustCrateSource: Send + Sync {
17    fn source_path(&self) -> SourcePath;
18    fn record_server_config(&self, config: ServerConfig);
19
20    fn host(&self) -> Arc<dyn Host>;
21    fn server(&self) -> Arc<dyn RustCrateServer>;
22    fn record_server_strategy(&self, config: ServerStrategy);
23
24    fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
25        config
26    }
27
28    fn send_to(&self, sink: &dyn RustCrateSink) {
29        let forward_res = sink.instantiate(&self.source_path());
30        if let Ok(instantiated) = forward_res {
31            self.record_server_config(instantiated());
32        } else {
33            drop(forward_res);
34            let instantiated = sink
35                .instantiate_reverse(&self.host(), self.server(), &|p| {
36                    self.wrap_reverse_server_config(p)
37                })
38                .unwrap();
39            self.record_server_strategy(instantiated(sink));
40        }
41    }
42}
43
/// A handle to the server side of a port.
///
/// [`ServerConfig::load_instantiated`] uses it to obtain the concrete port
/// and, for forwarded connections, the launched host that performs the
/// forwarding.
pub trait RustCrateServer: DynClone + Debug + Send + Sync {
    /// Returns the [`ServerPort`] this server exposes.
    fn get_port(&self) -> ServerPort;
    /// Returns the launched host associated with this server.
    fn launched_host(&self) -> Arc<dyn LaunchedHost>;
}
48
/// Thunk returned by [`RustCrateSink::instantiate_reverse`]. It is invoked
/// with the sink as `&dyn Any` and yields the [`ServerStrategy`] that the
/// source then records.
pub type ReverseSinkInstantiator = Box<dyn FnOnce(&dyn Any) -> ServerStrategy>;
50
/// The receiving side of a connection. Both methods only plan the
/// connection; the returned thunk performs the actual mutations when called.
pub trait RustCrateSink: Any + Send + Sync {
    /// Instantiate the sink as the source host connecting to the sink host.
    /// Returns a thunk that can be called to perform mutations that instantiate the sink.
    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>>;

    /// Instantiate the sink, but as the sink host connecting to the source host.
    /// Returns a thunk that can be called to perform mutations that instantiate the sink,
    /// taking a reference to this sink as `&dyn Any`.
    ///
    /// `wrap_client_port` is applied to the client-side [`ServerConfig`] before it is stored.
    fn instantiate_reverse(
        &self,
        server_host: &Arc<dyn Host>,
        server_sink: Arc<dyn RustCrateServer>,
        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
    ) -> Result<ReverseSinkInstantiator>;
}
65
/// A source that wraps another source and attaches a numeric tag to it.
pub struct TaggedSource {
    /// The wrapped source that all recording calls are delegated to.
    pub source: Arc<dyn RustCrateSource>,
    /// The tag added to the source path and to reverse-direction server configs.
    pub tag: u32,
}
70
71impl RustCrateSource for TaggedSource {
72    fn source_path(&self) -> SourcePath {
73        SourcePath::Tagged(Box::new(self.source.source_path()), self.tag)
74    }
75
76    fn record_server_config(&self, config: ServerConfig) {
77        self.source.record_server_config(config);
78    }
79
80    fn host(&self) -> Arc<dyn Host> {
81        self.source.host()
82    }
83
84    fn server(&self) -> Arc<dyn RustCrateServer> {
85        self.source.server()
86    }
87
88    fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
89        ServerConfig::Tagged(Box::new(config), self.tag)
90    }
91
92    fn record_server_strategy(&self, config: ServerStrategy) {
93        self.source.record_server_strategy(config);
94    }
95}
96
/// A placeholder endpoint usable as both source and sink: it records nothing
/// and always instantiates to the `Null` config/strategy.
pub struct NullSourceSink;
98
99impl RustCrateSource for NullSourceSink {
100    fn source_path(&self) -> SourcePath {
101        SourcePath::Null
102    }
103
104    fn host(&self) -> Arc<dyn Host> {
105        panic!("null source has no host")
106    }
107
108    fn server(&self) -> Arc<dyn RustCrateServer> {
109        panic!("null source has no server")
110    }
111
112    fn record_server_config(&self, _config: ServerConfig) {}
113    fn record_server_strategy(&self, _config: ServerStrategy) {}
114}
115
116impl RustCrateSink for NullSourceSink {
117    fn instantiate(&self, _client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
118        Ok(Box::new(|| ServerConfig::Null))
119    }
120
121    fn instantiate_reverse(
122        &self,
123        _server_host: &Arc<dyn Host>,
124        _server_sink: Arc<dyn RustCrateServer>,
125        _wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
126    ) -> Result<ReverseSinkInstantiator> {
127        Ok(Box::new(|_| ServerStrategy::Null))
128    }
129}
130
/// A sink that fans out to several target sinks, keyed by a `u32`.
pub struct DemuxSink {
    /// The target sink for each demux key.
    pub demux: HashMap<u32, Arc<dyn RustCrateSink>>,
}
134
135impl RustCrateSink for DemuxSink {
136    fn instantiate(&self, client_host: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
137        let mut thunk_map = HashMap::new();
138        for (key, target) in &self.demux {
139            thunk_map.insert(*key, target.instantiate(client_host)?);
140        }
141
142        Ok(Box::new(move || {
143            let instantiated_map = thunk_map
144                .into_iter()
145                .map(|(key, thunk)| (key, thunk()))
146                .collect();
147
148            ServerConfig::Demux(instantiated_map)
149        }))
150    }
151
152    fn instantiate_reverse(
153        &self,
154        server_host: &Arc<dyn Host>,
155        server_sink: Arc<dyn RustCrateServer>,
156        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
157    ) -> Result<ReverseSinkInstantiator> {
158        let mut thunk_map = HashMap::new();
159        for (key, target) in &self.demux {
160            thunk_map.insert(
161                *key,
162                target.instantiate_reverse(
163                    server_host,
164                    server_sink.clone(),
165                    // the parent wrapper selects the demux port for the parent defn, so do that first
166                    &|p| ServerConfig::DemuxSelect(Box::new(wrap_client_port(p)), *key),
167                )?,
168            );
169        }
170
171        Ok(Box::new(move |me| {
172            let me = me.downcast_ref::<DemuxSink>().unwrap();
173            let instantiated_map = thunk_map
174                .into_iter()
175                .map(|(key, thunk)| (key, (thunk)(me.demux.get(&key).unwrap())))
176                .collect();
177
178            ServerStrategy::Demux(instantiated_map)
179        }))
180    }
181}
182
/// A named port on a [`RustCrateService`]. It can act as a source, a sink,
/// or a server handle for that port.
#[derive(Clone, Debug)]
pub struct RustCratePortConfig {
    /// Weak handle to the service that owns this port.
    pub service: Weak<RwLock<RustCrateService>>,
    /// The host associated with the service.
    pub service_host: Arc<dyn Host>,
    /// Server ports of the service keyed by port name; read by `get_port`.
    pub service_server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
    /// Network hint passed along when planning a connection to this port.
    pub network_hint: PortNetworkHint,
    /// Name of the port; used as the key in the service's port maps.
    pub port: String,
    /// When `true`, connections to this port are appended to a `Merge` entry
    /// instead of requiring the port to be unused.
    pub merge: bool,
}
192
193impl RustCratePortConfig {
194    pub fn merge(&self) -> Self {
195        Self {
196            service: self.service.clone(),
197            service_host: self.service_host.clone(),
198            service_server_defns: self.service_server_defns.clone(),
199            network_hint: self.network_hint,
200            port: self.port.clone(),
201            merge: true,
202        }
203    }
204}
205
206impl RustCrateSource for RustCratePortConfig {
207    fn source_path(&self) -> SourcePath {
208        SourcePath::Direct(
209            self.service
210                .upgrade()
211                .unwrap()
212                .try_read()
213                .unwrap()
214                .on
215                .clone(),
216        )
217    }
218
219    fn host(&self) -> Arc<dyn Host> {
220        self.service_host.clone()
221    }
222
223    fn server(&self) -> Arc<dyn RustCrateServer> {
224        let from = self.service.upgrade().unwrap();
225        let from_read = from.try_read().unwrap();
226
227        Arc::new(RustCratePortConfig {
228            service: Arc::downgrade(&from),
229            service_host: from_read.on.clone(),
230            service_server_defns: from_read.server_defns.clone(),
231            network_hint: self.network_hint,
232            port: self.port.clone(),
233            merge: false,
234        })
235    }
236
237    fn record_server_config(&self, config: ServerConfig) {
238        let from = self.service.upgrade().unwrap();
239        let mut from_write = from.try_write().unwrap();
240
241        // TODO(shadaj): if already in this map, we want to broadcast
242        assert!(
243            !from_write.port_to_server.contains_key(&self.port),
244            "The port configuration is incorrect, for example, are you using a ConnectedDirect instead of a ConnectedDemux?"
245        );
246        from_write.port_to_server.insert(self.port.clone(), config);
247    }
248
249    fn record_server_strategy(&self, config: ServerStrategy) {
250        let from = self.service.upgrade().unwrap();
251        let mut from_write = from.try_write().unwrap();
252
253        assert!(!from_write.port_to_bind.contains_key(&self.port));
254        from_write.port_to_bind.insert(self.port.clone(), config);
255    }
256}
257
258impl RustCrateServer for RustCratePortConfig {
259    fn get_port(&self) -> ServerPort {
260        // we are in `deployment.start()`, so no one should be writing
261        let server_defns = self.service_server_defns.try_read().unwrap();
262        server_defns.get(&self.port).unwrap().clone()
263    }
264
265    fn launched_host(&self) -> Arc<dyn LaunchedHost> {
266        self.service_host.launched().unwrap()
267    }
268}
269
/// Describes the client side of a connection when planning how a server
/// should accept it (see `SourcePath::plan`).
pub enum SourcePath {
    /// No client; planned as `ServerStrategy::Null` / `ServerConfig::Null`.
    Null,
    /// A client on the given host; planned as `ServerStrategy::Direct`.
    Direct(Arc<dyn Host>),
    /// A client on the given host; planned as `ServerStrategy::Many`.
    Many(Arc<dyn Host>),
    /// An underlying path plus a tag; planned as `ServerStrategy::Tagged`
    /// around the underlying plan.
    Tagged(Box<SourcePath>, u32),
}
276
277impl SourcePath {
278    #[expect(
279        clippy::type_complexity,
280        reason = "internals (dyn Fn to defer instantiation)"
281    )]
282    fn plan<T: RustCrateServer + Clone + 'static>(
283        &self,
284        server: &T,
285        server_host: &dyn Host,
286        network_hint: PortNetworkHint,
287    ) -> Result<(Box<dyn FnOnce(&dyn Any) -> ServerStrategy>, ServerConfig)> {
288        match self {
289            SourcePath::Direct(client_host) => {
290                let (conn_type, bind_type) =
291                    server_host.strategy_as_server(client_host.deref(), network_hint)?;
292                let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
293                Ok((
294                    Box::new(|host| ServerStrategy::Direct(bind_type(host))),
295                    base_config,
296                ))
297            }
298
299            SourcePath::Many(client_host) => {
300                let (conn_type, bind_type) =
301                    server_host.strategy_as_server(client_host.deref(), network_hint)?;
302                let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
303                Ok((
304                    Box::new(|host| ServerStrategy::Many(bind_type(host))),
305                    base_config,
306                ))
307            }
308
309            SourcePath::Tagged(underlying, tag) => {
310                let (bind_type, base_config) =
311                    underlying.plan(server, server_host, network_hint)?;
312                let tag = *tag;
313                Ok((
314                    Box::new(move |host| ServerStrategy::Tagged(Box::new(bind_type(host)), tag)),
315                    ServerConfig::TaggedUnwrap(Box::new(base_config)),
316                ))
317            }
318
319            SourcePath::Null => Ok((Box::new(|_| ServerStrategy::Null), ServerConfig::Null)),
320        }
321    }
322}
323
324impl RustCrateSink for RustCratePortConfig {
325    fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
326        let server = self.service.upgrade().unwrap();
327        let server_read = server.try_read().unwrap();
328
329        let server_host = server_read.on.clone();
330
331        let (bind_type, base_config) =
332            client_path.plan(self, server_host.deref(), self.network_hint)?;
333
334        let server = server.clone();
335        let merge = self.merge;
336        let port = self.port.clone();
337        Ok(Box::new(move || {
338            let mut server_write = server.try_write().unwrap();
339            let bind_type = (bind_type)(&*server_write.on);
340
341            if merge {
342                let merge_config = server_write
343                    .port_to_bind
344                    .entry(port.clone())
345                    .or_insert(ServerStrategy::Merge(vec![]));
346                let merge_index = if let ServerStrategy::Merge(merge) = merge_config {
347                    merge.push(bind_type);
348                    merge.len() - 1
349                } else {
350                    panic!("Expected a merge connection definition")
351                };
352
353                ServerConfig::MergeSelect(Box::new(base_config), merge_index)
354            } else {
355                assert!(!server_write.port_to_bind.contains_key(&port));
356                server_write.port_to_bind.insert(port.clone(), bind_type);
357                base_config
358            }
359        }))
360    }
361
362    fn instantiate_reverse(
363        &self,
364        server_host: &Arc<dyn Host>,
365        server_sink: Arc<dyn RustCrateServer>,
366        wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
367    ) -> Result<ReverseSinkInstantiator> {
368        if !matches!(self.network_hint, PortNetworkHint::Auto) {
369            bail!("Trying to form collection where I am the client, but I have server hint")
370        }
371
372        let client = self.service.upgrade().unwrap();
373        let client_read = client.try_read().unwrap();
374
375        let server_host = server_host.clone();
376
377        let (conn_type, bind_type) =
378            server_host.strategy_as_server(client_read.on.deref(), PortNetworkHint::Auto)?;
379        let client_port = wrap_client_port(ServerConfig::from_strategy(&conn_type, server_sink));
380
381        let client = client.clone();
382        let merge = self.merge;
383        let port = self.port.clone();
384        Ok(Box::new(move |_| {
385            let mut client_write = client.try_write().unwrap();
386
387            if merge {
388                let merge_config = client_write
389                    .port_to_server
390                    .entry(port.clone())
391                    .or_insert(ServerConfig::Merge(vec![]));
392
393                if let ServerConfig::Merge(merge) = merge_config {
394                    merge.push(client_port);
395                } else {
396                    panic!()
397                };
398            } else {
399                assert!(!client_write.port_to_server.contains_key(&port));
400                client_write
401                    .port_to_server
402                    .insert(port.clone(), client_port);
403            };
404
405            ServerStrategy::Direct((bind_type)(&*client_write.on))
406        }))
407    }
408}
409
/// Describes how a client obtains the port it should connect to; resolved
/// into a concrete `ServerPort` by `load_instantiated`.
#[derive(Clone, Debug)]
pub enum ServerConfig {
    /// Use the server's port as reported by `get_port`.
    Direct(Arc<dyn RustCrateServer>),
    /// Use the server's port after forwarding it through the server's launched host.
    Forwarded(Arc<dyn RustCrateServer>),
    /// A demux that will be used at runtime to listen to many connections.
    Demux(HashMap<u32, ServerConfig>),
    /// The other side of a demux, with a port to extract the appropriate connection.
    DemuxSelect(Box<ServerConfig>, u32),
    /// A merge that will be used at runtime to combine many connections.
    Merge(Vec<ServerConfig>),
    /// The other side of a merge, with a port to extract the appropriate connection.
    MergeSelect(Box<ServerConfig>, usize),
    /// Wraps the resolved port of the inner config in `ServerPort::Tagged` with the given id.
    Tagged(Box<ServerConfig>, u32),
    /// Resolves the inner config, which must yield a `ServerPort::Tagged`, and strips the tag.
    TaggedUnwrap(Box<ServerConfig>),
    /// No connection; resolves to `ServerPort::Null`.
    Null,
}
426
427impl ServerConfig {
428    pub fn from_strategy(
429        strategy: &ClientStrategy,
430        server: Arc<dyn RustCrateServer>,
431    ) -> ServerConfig {
432        match strategy {
433            ClientStrategy::UnixSocket(_) | ClientStrategy::InternalTcpPort(_) => {
434                ServerConfig::Direct(server)
435            }
436            ClientStrategy::ForwardedTcpPort(_) => ServerConfig::Forwarded(server),
437        }
438    }
439}
440
441#[async_recursion]
442async fn forward_connection(conn: &ServerPort, target: &dyn LaunchedHost) -> ServerPort {
443    match conn {
444        ServerPort::UnixSocket(_) => panic!("Expected a TCP port to be forwarded"),
445        ServerPort::TcpPort(addr) => ServerPort::TcpPort(target.forward_port(addr).await.unwrap()),
446        ServerPort::Demux(demux) => {
447            let mut forwarded_map = HashMap::new();
448            for (key, conn) in demux {
449                forwarded_map.insert(*key, forward_connection(conn, target).await);
450            }
451            ServerPort::Demux(forwarded_map)
452        }
453        ServerPort::Merge(merge) => {
454            let mut forwarded_vec = Vec::new();
455            for conn in merge {
456                forwarded_vec.push(forward_connection(conn, target).await);
457            }
458            ServerPort::Merge(forwarded_vec)
459        }
460        ServerPort::Tagged(underlying, id) => {
461            ServerPort::Tagged(Box::new(forward_connection(underlying, target).await), *id)
462        }
463        ServerPort::Null => ServerPort::Null,
464    }
465}
466
467impl ServerConfig {
468    #[async_recursion]
469    pub async fn load_instantiated(
470        &self,
471        select: &(dyn Fn(ServerPort) -> ServerPort + Send + Sync),
472    ) -> ServerPort {
473        match self {
474            ServerConfig::Direct(server) => select(server.get_port()),
475
476            ServerConfig::Forwarded(server) => {
477                let selected = select(server.get_port());
478                forward_connection(&selected, server.launched_host().as_ref()).await
479            }
480
481            ServerConfig::Demux(demux) => {
482                let mut demux_map = HashMap::new();
483                for (key, conn) in demux {
484                    demux_map.insert(*key, conn.load_instantiated(select).await);
485                }
486                ServerPort::Demux(demux_map)
487            }
488
489            ServerConfig::DemuxSelect(underlying, key) => {
490                let key = *key;
491                underlying
492                    .load_instantiated(
493                        &(move |p| {
494                            if let ServerPort::Demux(mut mapping) = p {
495                                select(mapping.remove(&key).unwrap())
496                            } else {
497                                panic!("Expected a demux connection definition")
498                            }
499                        }),
500                    )
501                    .await
502            }
503
504            ServerConfig::Merge(merge) => {
505                let mut merge_vec = Vec::new();
506                for conn in merge {
507                    merge_vec.push(conn.load_instantiated(select).await);
508                }
509                ServerPort::Merge(merge_vec)
510            }
511
512            ServerConfig::MergeSelect(underlying, key) => {
513                let key = *key;
514                underlying
515                    .load_instantiated(
516                        &(move |p| {
517                            if let ServerPort::Merge(mut mapping) = p {
518                                select(mapping.remove(key))
519                            } else {
520                                panic!("Expected a merge connection definition")
521                            }
522                        }),
523                    )
524                    .await
525            }
526
527            ServerConfig::Tagged(underlying, id) => {
528                ServerPort::Tagged(Box::new(underlying.load_instantiated(select).await), *id)
529            }
530
531            ServerConfig::TaggedUnwrap(underlying) => {
532                let loaded = underlying.load_instantiated(select).await;
533                if let ServerPort::Tagged(underlying, _) = loaded {
534                    *underlying
535                } else {
536                    panic!("Expected a tagged connection definition")
537                }
538            }
539
540            ServerConfig::Null => ServerPort::Null,
541        }
542    }
543}