ALPHA: wire is a tool to deploy NixOS systems
wire.althaea.zone/
1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright 2024-2025 wire Contributors
3
4use futures::{FutureExt, StreamExt};
5use itertools::{Either, Itertools};
6use miette::{Diagnostic, IntoDiagnostic, Result};
7use std::any::Any;
8use std::collections::HashSet;
9use std::io::{Read, stderr};
10use std::sync::Arc;
11use std::sync::atomic::AtomicBool;
12use thiserror::Error;
13use tracing::{error, info};
14use wire_core::hive::node::{Context, GoalExecutor, Name, Node, Objective, StepState};
15use wire_core::hive::{Hive, HiveLocation};
16use wire_core::status::STATUS;
17use wire_core::{SubCommandModifiers, errors::HiveLibError};
18
19use crate::cli::{ApplyTarget, CommonVerbArgs, Partitions};
20
21#[derive(Debug, Error, Diagnostic)]
22#[error("node {} failed to apply", .0)]
23struct NodeError(
24 Name,
25 #[source]
26 #[diagnostic_source]
27 HiveLibError,
28);
29
30#[derive(Debug, Error, Diagnostic)]
31#[error("{} node(s) failed to apply.", .0.len())]
32struct NodeErrors(#[related] Vec<NodeError>);
33
34// returns Names and Tags
35fn read_apply_targets_from_stdin() -> Result<(Vec<String>, Vec<Name>)> {
36 let mut buf = String::new();
37 let mut stdin = std::io::stdin().lock();
38 stdin.read_to_string(&mut buf).into_diagnostic()?;
39
40 Ok(buf
41 .split_whitespace()
42 .map(|x| ApplyTarget::from(x.to_string()))
43 .fold((Vec::new(), Vec::new()), |(mut tags, mut names), target| {
44 match target {
45 ApplyTarget::Node(name) => names.push(name),
46 ApplyTarget::Tag(tag) => tags.push(tag),
47 ApplyTarget::Stdin => {}
48 }
49 (tags, names)
50 }))
51}
52
53fn resolve_targets(
54 on: &[ApplyTarget],
55 modifiers: &mut SubCommandModifiers,
56) -> (HashSet<String>, HashSet<Name>) {
57 on.iter().fold(
58 (HashSet::new(), HashSet::new()),
59 |(mut tags, mut names), target| {
60 match target {
61 ApplyTarget::Tag(tag) => {
62 tags.insert(tag.clone());
63 }
64 ApplyTarget::Node(name) => {
65 names.insert(name.clone());
66 }
67 ApplyTarget::Stdin => {
68 // implies non_interactive
69 modifiers.non_interactive = true;
70
71 let (found_tags, found_names) = read_apply_targets_from_stdin().unwrap();
72 names.extend(found_names);
73 tags.extend(found_tags);
74 }
75 }
76 (tags, names)
77 },
78 )
79}
80
81fn partition_arr<T>(arr: Vec<T>, partition: &Partitions) -> Vec<T>
82where
83 T: Any + Clone,
84{
85 if arr.is_empty() {
86 return arr;
87 }
88
89 let items_per_chunk = arr.len().div_ceil(partition.maximum);
90
91 arr.chunks(items_per_chunk)
92 .nth(partition.current - 1)
93 .unwrap_or(&[])
94 .to_vec()
95}
96
97pub async fn apply<F>(
98 hive: &mut Hive,
99 should_shutdown: Arc<AtomicBool>,
100 location: HiveLocation,
101 args: CommonVerbArgs,
102 partition: Partitions,
103 make_objective: F,
104 mut modifiers: SubCommandModifiers,
105) -> Result<()>
106where
107 F: Fn(&Name, &Node) -> Objective,
108{
109 let location = Arc::new(location);
110
111 let (tags, names) = resolve_targets(&args.on, &mut modifiers);
112
113 let selected_names: Vec<_> = hive
114 .nodes
115 .iter()
116 .filter(|(name, node)| {
117 args.on.is_empty()
118 || names.contains(name)
119 || node.tags.iter().any(|tag| tags.contains(tag))
120 })
121 .sorted_by_key(|(name, _)| *name)
122 .map(|(name, _)| name.clone())
123 .collect();
124
125 let num_selected = selected_names.len();
126
127 let partitioned_names = partition_arr(selected_names, &partition);
128
129 if num_selected != partitioned_names.len() {
130 info!(
131 "Partitioning reduced selected number of nodes from {num_selected} to {}",
132 partitioned_names.len()
133 );
134 }
135
136 STATUS
137 .lock()
138 .add_many(&partitioned_names.iter().collect::<Vec<_>>());
139
140 let mut set = hive
141 .nodes
142 .iter_mut()
143 .filter(|(name, _)| partitioned_names.contains(name))
144 .map(|(name, node)| {
145 info!("Resolved {:?} to include {}", args.on, name);
146
147 let objective = make_objective(name, node);
148
149 let context = Context {
150 node,
151 name,
152 objective,
153 state: StepState::default(),
154 hive_location: location.clone(),
155 modifiers,
156 should_quit: should_shutdown.clone(),
157 };
158
159 GoalExecutor::new(context)
160 .execute()
161 .map(move |result| (name, result))
162 })
163 .peekable();
164
165 if set.peek().is_none() {
166 error!("There are no nodes selected for deployment");
167 }
168
169 let futures = futures::stream::iter(set).buffer_unordered(args.parallel);
170 let result = futures.collect::<Vec<_>>().await;
171 let (successful, errors): (Vec<_>, Vec<_>) =
172 result
173 .into_iter()
174 .partition_map(|(name, result)| match result {
175 Ok(..) => Either::Left(name),
176 Err(err) => Either::Right((name, err)),
177 });
178
179 if !successful.is_empty() {
180 info!(
181 "Successfully applied goal to {} node(s): {:?}",
182 successful.len(),
183 successful
184 );
185 }
186
187 if !errors.is_empty() {
188 // clear the status bar if we are about to print error messages
189 STATUS.lock().clear(&mut stderr());
190
191 return Err(NodeErrors(
192 errors
193 .into_iter()
194 .map(|(name, error)| NodeError(name.clone(), error))
195 .collect(),
196 )
197 .into());
198 }
199
200 Ok(())
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `partition_arr` over a copy of `items`, selecting partition
    /// `current` (1-based) out of `maximum`.
    fn take_partition(items: &[usize], current: usize, maximum: usize) -> Vec<usize> {
        partition_arr(items.to_vec(), &Partitions { current, maximum })
    }

    #[test]
    #[allow(clippy::too_many_lines)]
    fn test_partitioning() {
        // an even count splits cleanly in half
        let even: Vec<usize> = (1..=10).collect();
        assert_eq!(even, partition_arr(even.clone(), &Partitions::default()));
        assert_eq!(vec![1, 2, 3, 4, 5], take_partition(&even, 1, 2));
        assert_eq!(vec![6, 7, 8, 9, 10], take_partition(&even, 2, 2));

        // test odd number: the first half receives the extra element
        let odd: Vec<usize> = (1..10).collect();
        assert_eq!(odd, partition_arr(odd.clone(), &Partitions::default()));
        assert_eq!(vec![1, 2, 3, 4, 5], take_partition(&odd, 1, 2));
        assert_eq!(vec![6, 7, 8, 9], take_partition(&odd, 2, 2));

        // test large number of partitions: one item per partition
        assert_eq!(even, partition_arr(even.clone(), &Partitions::default()));
        for current in 1..=10 {
            assert_eq!(vec![current], take_partition(&even, current, 10));
            assert_eq!(vec![current], take_partition(&even, current, 15));
        }

        // stretching thin with higher partitions will start to leave higher ones empty
        assert!(take_partition(&even, 11, 15).is_empty());

        // test the above holds for a lot of numbers
        for len in 1..1000 {
            let items: Vec<usize> = (0..len).collect();

            assert_eq!(items, partition_arr(items.clone(), &Partitions::default()));

            let split_index = len.div_ceil(2).min(len);
            let (head, tail) = items.split_at(split_index);

            assert_eq!(head, take_partition(&items, 1, 2));
            assert_eq!(tail, take_partition(&items, 2, 2));
        }
    }
}