parallel.rs (5542B)
1 use super::ProblemSolver; 2 use std::ops::{Deref, DerefMut}; 3 4 use futures::ready; 5 use std::future::Future; 6 use std::pin::Pin; 7 8 pub trait AsyncTester { 9 type Result: Future<Output = Vec<bool>>; 10 11 fn test_async(&self, query: Vec<(usize, usize)>) -> Self::Result; 12 } 13 14 pub struct ParallelProblemSolver<T> 15 where 16 T: AsyncTester, 17 { 18 solver: ProblemSolver, 19 current_test: Option<(T::Result, Vec<usize>)>, 20 } 21 22 impl<T: AsyncTester> Deref for ParallelProblemSolver<T> { 23 type Target = ProblemSolver; 24 25 fn deref(&self) -> &Self::Target { 26 &self.solver 27 } 28 } 29 30 impl<T: AsyncTester> DerefMut for ParallelProblemSolver<T> { 31 fn deref_mut(&mut self) -> &mut Self::Target { 32 &mut self.solver 33 } 34 } 35 36 impl<T: AsyncTester> ParallelProblemSolver<T> { 37 pub fn new(width: usize, depth: usize) -> Self { 38 Self { 39 solver: ProblemSolver::new(width, depth), 40 current_test: None, 41 } 42 } 43 } 44 45 type TestQuery = (Vec<(usize, usize)>, Vec<usize>); 46 47 impl<T: AsyncTester> ParallelProblemSolver<T> { 48 pub fn try_generate_complete_candidate(&mut self) -> bool { 49 while !self.is_complete() { 50 while self.is_current_cell_missing() { 51 if !self.try_advance_source() { 52 return false; 53 } 54 } 55 if !self.try_advance_resource() { 56 return false; 57 } 58 } 59 true 60 } 61 62 fn try_generate_test_query(&mut self) -> Result<TestQuery, usize> { 63 let mut test_cells = vec![]; 64 let query = self 65 .solution 66 .iter() 67 .enumerate() 68 .filter_map(|(res_idx, source_idx)| { 69 let cell = self.cache[res_idx][*source_idx]; 70 match cell { 71 None => { 72 test_cells.push(res_idx); 73 Some(Ok((res_idx, *source_idx))) 74 } 75 Some(false) => Some(Err(res_idx)), 76 Some(true) => None, 77 } 78 }) 79 .collect::<Result<_, _>>()?; 80 Ok((query, test_cells)) 81 } 82 83 fn apply_test_result( 84 &mut self, 85 resources: Vec<bool>, 86 testing_cells: Vec<usize>, 87 ) -> Result<(), usize> { 88 let mut first_missing = None; 89 for (result, res_idx) in resources.into_iter().zip(testing_cells) { 90 let source_idx = self.solution[res_idx]; 91 self.cache[res_idx][source_idx] = Some(result); 92 if !result && first_missing.is_none() { 93 first_missing = Some(res_idx); 94 } 95 } 96 if let Some(idx) = first_missing { 97 Err(idx) 98 } else { 99 Ok(()) 100 } 101 } 102 103 pub fn try_poll_next( 104 mut self: std::pin::Pin<&mut Self>, 105 cx: &mut std::task::Context<'_>, 106 tester: &T, 107 prefetch: bool, 108 ) -> std::task::Poll<Result<Option<Vec<usize>>, usize>> 109 where 110 <T as AsyncTester>::Result: Unpin, 111 { 112 if self.width == 0 || self.depth == 0 { 113 return Ok(None).into(); 114 } 115 116 'outer: loop { 117 if let Some((test, testing_cells)) = &mut self.current_test { 118 let pinned = Pin::new(test); 119 let set = ready!(pinned.poll(cx)); 120 let testing_cells = testing_cells.clone(); 121 122 if let Err(res_idx) = self.apply_test_result(set, testing_cells) { 123 self.idx = res_idx; 124 self.prune(); 125 if !self.bail() { 126 if let Some(res_idx) = self.has_missing_cell() { 127 return Err(res_idx).into(); 128 } else { 129 return Ok(None).into(); 130 } 131 } 132 self.current_test = None; 133 continue 'outer; 134 } else { 135 self.current_test = None; 136 if !prefetch { 137 self.dirty = true; 138 } 139 return Ok(Some(self.solution.clone())).into(); 140 } 141 } else { 142 if self.dirty { 143 if !self.bail() { 144 if let Some(res_idx) = self.has_missing_cell() { 145 return Err(res_idx).into(); 146 } else { 147 return Ok(None).into(); 148 } 149 } 150 self.dirty = false; 151 } 152 while self.try_generate_complete_candidate() { 153 match self.try_generate_test_query() { 154 Ok((query, testing_cells)) => { 155 self.current_test = Some((tester.test_async(query), testing_cells)); 156 continue 'outer; 157 } 158 Err(res_idx) => { 159 self.idx = res_idx; 160 self.prune(); 161 if !self.bail() { 162 if let Some(res_idx) = self.has_missing_cell() { 163 return Err(res_idx).into(); 164 } else { 165 return Ok(None).into(); 166 } 167 } 168 } 169 } 170 } 171 return Ok(None).into(); 172 } 173 } 174 } 175 }