aboutsummaryrefslogtreecommitdiff
path: root/src/toposort.rs
blob: 5e4459046d104462ef650a4c38df165df139177e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use std::collections::HashSet;
use std::hash::Hash;

#[derive(Debug, thiserror::Error)]
#[error("Cycle detected")]
pub struct TopoSortCycle;

/// Topologically sorts `input` so that every key appears after all of its dependencies.
///
/// `deps` returns the direct dependencies of a key. Keys reached only through `deps`
/// are included in the output even if they do not appear in `input`. Each distinct key
/// appears exactly once. Keys are emitted in depth-first order: roots are taken in
/// `input` order, and each key's dependencies are visited in the order `deps` yields
/// them. `deps` is called at most once per distinct key.
///
/// The traversal keeps an explicit stack on the heap instead of recursing, so very
/// long dependency chains cannot overflow the thread's call stack.
///
/// # Errors
///
/// Returns [`TopoSortCycle`] if a dependency cycle is reachable from `input`.
pub fn toposort<K, I>(
    input: impl IntoIterator<Item = K>,
    deps: impl Fn(&K) -> I,
) -> Result<Vec<K>, TopoSortCycle>
where
    K: Eq + Hash + Clone,
    I: Iterator<Item = K>,
{
    // Keys already emitted, in output order, plus a set for O(1) "already done" checks.
    let mut result: Vec<K> = Vec::new();
    let mut result_set: HashSet<K> = HashSet::new();
    // Keys on the current DFS path; reaching one of them again means a cycle.
    let mut on_path: HashSet<K> = HashSet::new();
    // The current DFS path: each frame is a key and its not-yet-visited dependencies.
    let mut path: Vec<(K, I)> = Vec::new();

    for root in input {
        if result_set.contains(&root) {
            continue;
        }
        // `path` is empty between roots, so this insert always succeeds.
        on_path.insert(root.clone());
        let root_deps = deps(&root);
        path.push((root, root_deps));

        while let Some((_, pending)) = path.last_mut() {
            match pending.next() {
                Some(dep) => {
                    if result_set.contains(&dep) {
                        continue;
                    }
                    if !on_path.insert(dep.clone()) {
                        return Err(TopoSortCycle);
                    }
                    let dep_deps = deps(&dep);
                    path.push((dep, dep_deps));
                }
                None => {
                    // Every dependency of the top key has been emitted; emit the key itself.
                    let (done, _) = path.pop().expect("`last_mut` just returned Some");
                    let removed = on_path.remove(&done);
                    assert!(removed, "finished key must be on the DFS path");
                    result_set.insert(done.clone());
                    result.push(done);
                }
            }
        }
    }

    Ok(result)
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::{toposort, TopoSortCycle};

    /// Runs [`toposort`] on a tiny graph DSL.
    ///
    /// `input` is a space-separated list of keys, in input order. A key may be written
    /// as `k->a,b` to declare that `k` depends on `a` and `b`; dependencies do not have
    /// to be listed as keys themselves.
    fn test_toposort(input: &str) -> Result<Vec<&str>, TopoSortCycle> {
        let mut keys: Vec<&str> = Vec::new();
        let mut edges: HashMap<&str, Vec<&str>> = HashMap::new();
        for part in input.split(' ') {
            match part.split_once("->") {
                Some((k, vs)) => {
                    keys.push(k);
                    edges.insert(k, vs.split(',').collect());
                }
                None => keys.push(part),
            }
        }

        toposort(keys, |k| {
            edges
                .get(k)
                .map(Vec::as_slice)
                .unwrap_or_default()
                .iter()
                .copied()
        })
    }

    /// Asserts that sorting `input` succeeds and yields the space-separated `expected` keys.
    fn test_toposort_check(input: &str, expected: &str) {
        let sorted = test_toposort(input).unwrap();
        let expected = expected.split(' ').collect::<Vec<_>>();
        assert_eq!(expected, sorted);
    }

    #[test]
    fn test() {
        test_toposort_check("1 2 3", "1 2 3");
        test_toposort_check("1->2 2->3 3", "3 2 1");
        test_toposort_check("1 2->1 3->2", "1 2 3");
        test_toposort_check("1->2,3 2->3 3", "3 2 1");
    }

    #[test]
    fn diamond() {
        // 1 is shared by 2 and 3 but must be emitted once, before both of them.
        test_toposort_check("4->2,3 2->1 3->1 1", "1 2 3 4");
    }

    #[test]
    fn duplicates_and_unlisted_dependencies() {
        test_toposort_check("1 1 2->1", "1 2");
        // 2 is only reachable as a dependency of 1, yet still appears in the output.
        test_toposort_check("1->2", "2 1");
    }

    #[test]
    fn cycle() {
        assert!(test_toposort("1->1").is_err());
        assert!(test_toposort("1->2 2->1").is_err());
        assert!(test_toposort("1->2 2->3 3->1").is_err());
        // A cycle that is not reachable from the first key is still detected.
        assert!(test_toposort("1 2->3 3->2").is_err());
    }
}