1use std::any::Any;
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::ops::Deref;
5use std::sync::{Arc, Weak};
6
7use anyhow::{Result, bail};
8use async_recursion::async_recursion;
9use dyn_clone::DynClone;
10use hydro_deploy_integration::ServerPort;
11use tokio::sync::RwLock;
12
13use super::RustCrateService;
14use crate::{ClientStrategy, Host, LaunchedHost, PortNetworkHint, ServerStrategy};
15
16pub trait RustCrateSource: Send + Sync {
17 fn source_path(&self) -> SourcePath;
18 fn record_server_config(&self, config: ServerConfig);
19
20 fn host(&self) -> Arc<dyn Host>;
21 fn server(&self) -> Arc<dyn RustCrateServer>;
22 fn record_server_strategy(&self, config: ServerStrategy);
23
24 fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
25 config
26 }
27
28 fn send_to(&self, sink: &dyn RustCrateSink) {
29 let forward_res = sink.instantiate(&self.source_path());
30 if let Ok(instantiated) = forward_res {
31 self.record_server_config(instantiated());
32 } else {
33 drop(forward_res);
34 let instantiated = sink
35 .instantiate_reverse(&self.host(), self.server(), &|p| {
36 self.wrap_reverse_server_config(p)
37 })
38 .unwrap();
39 self.record_server_strategy(instantiated(sink));
40 }
41 }
42}
43
44pub trait RustCrateServer: DynClone + Debug + Send + Sync {
45 fn get_port(&self) -> ServerPort;
46 fn launched_host(&self) -> Arc<dyn LaunchedHost>;
47}
48
49pub type ReverseSinkInstantiator = Box<dyn FnOnce(&dyn Any) -> ServerStrategy>;
50
51pub trait RustCrateSink: Any + Send + Sync {
52 fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>>;
55
56 fn instantiate_reverse(
59 &self,
60 server_host: &Arc<dyn Host>,
61 server_sink: Arc<dyn RustCrateServer>,
62 wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
63 ) -> Result<ReverseSinkInstantiator>;
64}
65
66pub struct TaggedSource {
67 pub source: Arc<dyn RustCrateSource>,
68 pub tag: u32,
69}
70
71impl RustCrateSource for TaggedSource {
72 fn source_path(&self) -> SourcePath {
73 SourcePath::Tagged(Box::new(self.source.source_path()), self.tag)
74 }
75
76 fn record_server_config(&self, config: ServerConfig) {
77 self.source.record_server_config(config);
78 }
79
80 fn host(&self) -> Arc<dyn Host> {
81 self.source.host()
82 }
83
84 fn server(&self) -> Arc<dyn RustCrateServer> {
85 self.source.server()
86 }
87
88 fn wrap_reverse_server_config(&self, config: ServerConfig) -> ServerConfig {
89 ServerConfig::Tagged(Box::new(config), self.tag)
90 }
91
92 fn record_server_strategy(&self, config: ServerStrategy) {
93 self.source.record_server_strategy(config);
94 }
95}
96
/// A placeholder that can stand in for either a source or a sink.
///
/// Connections involving it resolve to the `Null` path, config and strategy.
pub struct NullSourceSink;
98
99impl RustCrateSource for NullSourceSink {
100 fn source_path(&self) -> SourcePath {
101 SourcePath::Null
102 }
103
104 fn host(&self) -> Arc<dyn Host> {
105 panic!("null source has no host")
106 }
107
108 fn server(&self) -> Arc<dyn RustCrateServer> {
109 panic!("null source has no server")
110 }
111
112 fn record_server_config(&self, _config: ServerConfig) {}
113 fn record_server_strategy(&self, _config: ServerStrategy) {}
114}
115
116impl RustCrateSink for NullSourceSink {
117 fn instantiate(&self, _client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
118 Ok(Box::new(|| ServerConfig::Null))
119 }
120
121 fn instantiate_reverse(
122 &self,
123 _server_host: &Arc<dyn Host>,
124 _server_sink: Arc<dyn RustCrateServer>,
125 _wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
126 ) -> Result<ReverseSinkInstantiator> {
127 Ok(Box::new(|_| ServerStrategy::Null))
128 }
129}
130
131pub struct DemuxSink {
132 pub demux: HashMap<u32, Arc<dyn RustCrateSink>>,
133}
134
135impl RustCrateSink for DemuxSink {
136 fn instantiate(&self, client_host: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
137 let mut thunk_map = HashMap::new();
138 for (key, target) in &self.demux {
139 thunk_map.insert(*key, target.instantiate(client_host)?);
140 }
141
142 Ok(Box::new(move || {
143 let instantiated_map = thunk_map
144 .into_iter()
145 .map(|(key, thunk)| (key, thunk()))
146 .collect();
147
148 ServerConfig::Demux(instantiated_map)
149 }))
150 }
151
152 fn instantiate_reverse(
153 &self,
154 server_host: &Arc<dyn Host>,
155 server_sink: Arc<dyn RustCrateServer>,
156 wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
157 ) -> Result<ReverseSinkInstantiator> {
158 let mut thunk_map = HashMap::new();
159 for (key, target) in &self.demux {
160 thunk_map.insert(
161 *key,
162 target.instantiate_reverse(
163 server_host,
164 server_sink.clone(),
165 &|p| ServerConfig::DemuxSelect(Box::new(wrap_client_port(p)), *key),
167 )?,
168 );
169 }
170
171 Ok(Box::new(move |me| {
172 let me = me.downcast_ref::<DemuxSink>().unwrap();
173 let instantiated_map = thunk_map
174 .into_iter()
175 .map(|(key, thunk)| (key, (thunk)(me.demux.get(&key).unwrap())))
176 .collect();
177
178 ServerStrategy::Demux(instantiated_map)
179 }))
180 }
181}
182
183#[derive(Clone, Debug)]
184pub struct RustCratePortConfig {
185 pub service: Weak<RwLock<RustCrateService>>,
186 pub service_host: Arc<dyn Host>,
187 pub service_server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
188 pub network_hint: PortNetworkHint,
189 pub port: String,
190 pub merge: bool,
191}
192
193impl RustCratePortConfig {
194 pub fn merge(&self) -> Self {
195 Self {
196 service: self.service.clone(),
197 service_host: self.service_host.clone(),
198 service_server_defns: self.service_server_defns.clone(),
199 network_hint: self.network_hint,
200 port: self.port.clone(),
201 merge: true,
202 }
203 }
204}
205
206impl RustCrateSource for RustCratePortConfig {
207 fn source_path(&self) -> SourcePath {
208 SourcePath::Direct(
209 self.service
210 .upgrade()
211 .unwrap()
212 .try_read()
213 .unwrap()
214 .on
215 .clone(),
216 )
217 }
218
219 fn host(&self) -> Arc<dyn Host> {
220 self.service_host.clone()
221 }
222
223 fn server(&self) -> Arc<dyn RustCrateServer> {
224 let from = self.service.upgrade().unwrap();
225 let from_read = from.try_read().unwrap();
226
227 Arc::new(RustCratePortConfig {
228 service: Arc::downgrade(&from),
229 service_host: from_read.on.clone(),
230 service_server_defns: from_read.server_defns.clone(),
231 network_hint: self.network_hint,
232 port: self.port.clone(),
233 merge: false,
234 })
235 }
236
237 fn record_server_config(&self, config: ServerConfig) {
238 let from = self.service.upgrade().unwrap();
239 let mut from_write = from.try_write().unwrap();
240
241 assert!(
243 !from_write.port_to_server.contains_key(&self.port),
244 "The port configuration is incorrect, for example, are you using a ConnectedDirect instead of a ConnectedDemux?"
245 );
246 from_write.port_to_server.insert(self.port.clone(), config);
247 }
248
249 fn record_server_strategy(&self, config: ServerStrategy) {
250 let from = self.service.upgrade().unwrap();
251 let mut from_write = from.try_write().unwrap();
252
253 assert!(!from_write.port_to_bind.contains_key(&self.port));
254 from_write.port_to_bind.insert(self.port.clone(), config);
255 }
256}
257
258impl RustCrateServer for RustCratePortConfig {
259 fn get_port(&self) -> ServerPort {
260 let server_defns = self.service_server_defns.try_read().unwrap();
262 server_defns.get(&self.port).unwrap().clone()
263 }
264
265 fn launched_host(&self) -> Arc<dyn LaunchedHost> {
266 self.service_host.launched().unwrap()
267 }
268}
269
270pub enum SourcePath {
271 Null,
272 Direct(Arc<dyn Host>),
273 Many(Arc<dyn Host>),
274 Tagged(Box<SourcePath>, u32),
275}
276
277impl SourcePath {
278 #[expect(
279 clippy::type_complexity,
280 reason = "internals (dyn Fn to defer instantiation)"
281 )]
282 fn plan<T: RustCrateServer + Clone + 'static>(
283 &self,
284 server: &T,
285 server_host: &dyn Host,
286 network_hint: PortNetworkHint,
287 ) -> Result<(Box<dyn FnOnce(&dyn Any) -> ServerStrategy>, ServerConfig)> {
288 match self {
289 SourcePath::Direct(client_host) => {
290 let (conn_type, bind_type) =
291 server_host.strategy_as_server(client_host.deref(), network_hint)?;
292 let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
293 Ok((
294 Box::new(|host| ServerStrategy::Direct(bind_type(host))),
295 base_config,
296 ))
297 }
298
299 SourcePath::Many(client_host) => {
300 let (conn_type, bind_type) =
301 server_host.strategy_as_server(client_host.deref(), network_hint)?;
302 let base_config = ServerConfig::from_strategy(&conn_type, Arc::new(server.clone()));
303 Ok((
304 Box::new(|host| ServerStrategy::Many(bind_type(host))),
305 base_config,
306 ))
307 }
308
309 SourcePath::Tagged(underlying, tag) => {
310 let (bind_type, base_config) =
311 underlying.plan(server, server_host, network_hint)?;
312 let tag = *tag;
313 Ok((
314 Box::new(move |host| ServerStrategy::Tagged(Box::new(bind_type(host)), tag)),
315 ServerConfig::TaggedUnwrap(Box::new(base_config)),
316 ))
317 }
318
319 SourcePath::Null => Ok((Box::new(|_| ServerStrategy::Null), ServerConfig::Null)),
320 }
321 }
322}
323
324impl RustCrateSink for RustCratePortConfig {
325 fn instantiate(&self, client_path: &SourcePath) -> Result<Box<dyn FnOnce() -> ServerConfig>> {
326 let server = self.service.upgrade().unwrap();
327 let server_read = server.try_read().unwrap();
328
329 let server_host = server_read.on.clone();
330
331 let (bind_type, base_config) =
332 client_path.plan(self, server_host.deref(), self.network_hint)?;
333
334 let server = server.clone();
335 let merge = self.merge;
336 let port = self.port.clone();
337 Ok(Box::new(move || {
338 let mut server_write = server.try_write().unwrap();
339 let bind_type = (bind_type)(&*server_write.on);
340
341 if merge {
342 let merge_config = server_write
343 .port_to_bind
344 .entry(port.clone())
345 .or_insert(ServerStrategy::Merge(vec![]));
346 let merge_index = if let ServerStrategy::Merge(merge) = merge_config {
347 merge.push(bind_type);
348 merge.len() - 1
349 } else {
350 panic!("Expected a merge connection definition")
351 };
352
353 ServerConfig::MergeSelect(Box::new(base_config), merge_index)
354 } else {
355 assert!(!server_write.port_to_bind.contains_key(&port));
356 server_write.port_to_bind.insert(port.clone(), bind_type);
357 base_config
358 }
359 }))
360 }
361
362 fn instantiate_reverse(
363 &self,
364 server_host: &Arc<dyn Host>,
365 server_sink: Arc<dyn RustCrateServer>,
366 wrap_client_port: &dyn Fn(ServerConfig) -> ServerConfig,
367 ) -> Result<ReverseSinkInstantiator> {
368 if !matches!(self.network_hint, PortNetworkHint::Auto) {
369 bail!("Trying to form collection where I am the client, but I have server hint")
370 }
371
372 let client = self.service.upgrade().unwrap();
373 let client_read = client.try_read().unwrap();
374
375 let server_host = server_host.clone();
376
377 let (conn_type, bind_type) =
378 server_host.strategy_as_server(client_read.on.deref(), PortNetworkHint::Auto)?;
379 let client_port = wrap_client_port(ServerConfig::from_strategy(&conn_type, server_sink));
380
381 let client = client.clone();
382 let merge = self.merge;
383 let port = self.port.clone();
384 Ok(Box::new(move |_| {
385 let mut client_write = client.try_write().unwrap();
386
387 if merge {
388 let merge_config = client_write
389 .port_to_server
390 .entry(port.clone())
391 .or_insert(ServerConfig::Merge(vec![]));
392
393 if let ServerConfig::Merge(merge) = merge_config {
394 merge.push(client_port);
395 } else {
396 panic!()
397 };
398 } else {
399 assert!(!client_write.port_to_server.contains_key(&port));
400 client_write
401 .port_to_server
402 .insert(port.clone(), client_port);
403 };
404
405 ServerStrategy::Direct((bind_type)(&*client_write.on))
406 }))
407 }
408}
409
410#[derive(Clone, Debug)]
411pub enum ServerConfig {
412 Direct(Arc<dyn RustCrateServer>),
413 Forwarded(Arc<dyn RustCrateServer>),
414 Demux(HashMap<u32, ServerConfig>),
416 DemuxSelect(Box<ServerConfig>, u32),
418 Merge(Vec<ServerConfig>),
420 MergeSelect(Box<ServerConfig>, usize),
422 Tagged(Box<ServerConfig>, u32),
423 TaggedUnwrap(Box<ServerConfig>),
424 Null,
425}
426
427impl ServerConfig {
428 pub fn from_strategy(
429 strategy: &ClientStrategy,
430 server: Arc<dyn RustCrateServer>,
431 ) -> ServerConfig {
432 match strategy {
433 ClientStrategy::UnixSocket(_) | ClientStrategy::InternalTcpPort(_) => {
434 ServerConfig::Direct(server)
435 }
436 ClientStrategy::ForwardedTcpPort(_) => ServerConfig::Forwarded(server),
437 }
438 }
439}
440
441#[async_recursion]
442async fn forward_connection(conn: &ServerPort, target: &dyn LaunchedHost) -> ServerPort {
443 match conn {
444 ServerPort::UnixSocket(_) => panic!("Expected a TCP port to be forwarded"),
445 ServerPort::TcpPort(addr) => ServerPort::TcpPort(target.forward_port(addr).await.unwrap()),
446 ServerPort::Demux(demux) => {
447 let mut forwarded_map = HashMap::new();
448 for (key, conn) in demux {
449 forwarded_map.insert(*key, forward_connection(conn, target).await);
450 }
451 ServerPort::Demux(forwarded_map)
452 }
453 ServerPort::Merge(merge) => {
454 let mut forwarded_vec = Vec::new();
455 for conn in merge {
456 forwarded_vec.push(forward_connection(conn, target).await);
457 }
458 ServerPort::Merge(forwarded_vec)
459 }
460 ServerPort::Tagged(underlying, id) => {
461 ServerPort::Tagged(Box::new(forward_connection(underlying, target).await), *id)
462 }
463 ServerPort::Null => ServerPort::Null,
464 }
465}
466
467impl ServerConfig {
468 #[async_recursion]
469 pub async fn load_instantiated(
470 &self,
471 select: &(dyn Fn(ServerPort) -> ServerPort + Send + Sync),
472 ) -> ServerPort {
473 match self {
474 ServerConfig::Direct(server) => select(server.get_port()),
475
476 ServerConfig::Forwarded(server) => {
477 let selected = select(server.get_port());
478 forward_connection(&selected, server.launched_host().as_ref()).await
479 }
480
481 ServerConfig::Demux(demux) => {
482 let mut demux_map = HashMap::new();
483 for (key, conn) in demux {
484 demux_map.insert(*key, conn.load_instantiated(select).await);
485 }
486 ServerPort::Demux(demux_map)
487 }
488
489 ServerConfig::DemuxSelect(underlying, key) => {
490 let key = *key;
491 underlying
492 .load_instantiated(
493 &(move |p| {
494 if let ServerPort::Demux(mut mapping) = p {
495 select(mapping.remove(&key).unwrap())
496 } else {
497 panic!("Expected a demux connection definition")
498 }
499 }),
500 )
501 .await
502 }
503
504 ServerConfig::Merge(merge) => {
505 let mut merge_vec = Vec::new();
506 for conn in merge {
507 merge_vec.push(conn.load_instantiated(select).await);
508 }
509 ServerPort::Merge(merge_vec)
510 }
511
512 ServerConfig::MergeSelect(underlying, key) => {
513 let key = *key;
514 underlying
515 .load_instantiated(
516 &(move |p| {
517 if let ServerPort::Merge(mut mapping) = p {
518 select(mapping.remove(key))
519 } else {
520 panic!("Expected a merge connection definition")
521 }
522 }),
523 )
524 .await
525 }
526
527 ServerConfig::Tagged(underlying, id) => {
528 ServerPort::Tagged(Box::new(underlying.load_instantiated(select).await), *id)
529 }
530
531 ServerConfig::TaggedUnwrap(underlying) => {
532 let loaded = underlying.load_instantiated(select).await;
533 if let ServerPort::Tagged(underlying, _) = loaded {
534 *underlying
535 } else {
536 panic!("Expected a tagged connection definition")
537 }
538 }
539
540 ServerConfig::Null => ServerPort::Null,
541 }
542 }
543}