borrow/
lib.rs

1use std::{marker::PhantomData, ptr::NonNull};
2
/// A unique borrow that has been put to sleep so it can be revived later.
///
/// [`DormantMutRef::new`] takes a `&'a mut T`, immediately hands back a
/// reborrow that carries the *same* lifetime `'a`, and keeps a raw pointer to
/// the target. Once the caller has stopped using that reborrow (and everything
/// derived from it), [`DormantMutRef::awaken`] turns the pointer back into a
/// `&'a mut T`.
///
/// # Motivation
///
/// An entry-style API wants to return a reference *into* a structure together
/// with a way to reach the *whole* structure again (for example to update a
/// length counter). Written with two plain `&'a mut` references, the borrow
/// checker rejects it:
///
/// ```compile_fail
/// struct Base {
///     buf: Vec<String>,
///     len: usize, // some other data
/// }
///
/// enum Entry<'a> {
///     Occupied(OccupiedEntry<'a>),
///     Vacant(VacantEntry),
/// }
/// use Entry::Occupied;
///
/// struct OccupiedEntry<'a> {
///     handle: &'a mut String,
///     base: &'a mut Base,
/// }
/// struct VacantEntry;
///
/// impl Base {
///     pub fn new() -> Self { Self { buf: vec![], len: 0 } }
///     pub fn entry(&mut self, key: usize) -> Entry {
///         match self.buf.get_mut(key) {
///             Some(handle) => Occupied(OccupiedEntry { base: self, handle }),
///             None => unimplemented!(),
///         }
///     }
/// }
/// ```
///
/// which fails with:
///
/// ```text
/// error[E0499]: cannot borrow `*self` as mutable more than once at a time
///   --> src/lib.rs:26:60
///    |
/// 22 |     pub fn entry(&mut self, key: usize) -> Entry {
///    |                  - let's call the lifetime of this reference `'1`
/// 23 |         match self.buf.get_mut(key) {
///    |               --------------------- first mutable borrow occurs here
/// 24 |             Some(handle) => Occupied(OccupiedEntry { base: self, handle }),
///    |                             -------------------------------^^^^-----------
///    |                             |                              |
///    |                             |                              second mutable borrow occurs here
///    |                             returning this value requires that `self.buf` is borrowed for `'1`
/// ```
///
/// Storing a `DormantMutRef<'a, Base>` instead of the second `&'a mut Base`
/// avoids the conflict. In exchange, the caller takes on (via `unsafe`) the
/// obligation never to use the reborrow and the awakened borrow at the same
/// time.
///
/// ## References
/// - <https://www.reddit.com/r/rust/comments/11f45re>
/// - <https://doc.rust-lang.org/nomicon/lifetime-mismatch.html#improperly-reduced-borrows>
pub struct DormantMutRef<'a, T> {
    /// Pointer to the target of the captured borrow.
    ptr: NonNull<T>,
    /// Ties this value to the lifetime `'a` as if it still held a `&'a mut T`.
    _marker: PhantomData<&'a mut T>,
}

// SAFETY: a `DormantMutRef<'a, T>` only ever yields references that a
// `&'a mut T` could also yield (`&'a mut T` or `&'a T`), so it is exactly as
// thread-safe as that reference. The manual impls are needed because
// `NonNull<T>` alone makes the type neither `Send` nor `Sync`.
unsafe impl<'a, T> Sync for DormantMutRef<'a, T> where &'a mut T: Sync {}
unsafe impl<'a, T> Send for DormantMutRef<'a, T> where &'a mut T: Send {}

impl<'a, T> DormantMutRef<'a, T> {
    /// Captures the unique borrow `t` and immediately reborrows it.
    ///
    /// To the compiler, the returned reference has the same lifetime `'a` as
    /// `t`. The caller promises to stop using it (and anything derived from
    /// it) before calling [`awaken`](Self::awaken),
    /// [`reborrow`](Self::reborrow) or
    /// [`reborrow_shared`](Self::reborrow_shared).
    pub fn new(t: &'a mut T) -> (&'a mut T, Self) {
        let ptr = NonNull::from(t);
        // SAFETY: `ptr` comes from a live `&'a mut T`. That borrow is held for
        // all of `'a` via `_marker`, and this is the only reference handed out,
        // so it is unique.
        let new_ref = unsafe { &mut *ptr.as_ptr() };
        (new_ref, Self { ptr, _marker: PhantomData })
    }

    /// Reverts to the unique borrow originally captured by [`new`](Self::new).
    ///
    /// # Safety
    ///
    /// The reborrow must have ended. The reference returned by `new`, every
    /// reference previously obtained from [`reborrow`](Self::reborrow) or
    /// [`reborrow_shared`](Self::reborrow_shared), and every pointer or
    /// reference derived from any of them must not be used anymore.
    pub unsafe fn awaken(self) -> &'a mut T {
        // SAFETY: the caller's contract makes this reference unique again.
        unsafe { &mut *self.ptr.as_ptr() }
    }

    /// Borrows a new mutable reference to the target of the captured borrow.
    ///
    /// Unlike [`awaken`](Self::awaken), `self` is kept, so it can later be
    /// awakened or reborrowed again.
    ///
    /// # Safety
    ///
    /// Same as [`awaken`](Self::awaken). In addition, the returned reference
    /// must no longer be in use when `self` is awakened or reborrowed again.
    pub unsafe fn reborrow(&mut self) -> &'a mut T {
        // SAFETY: the caller's contract makes this reference unique again.
        unsafe { &mut *self.ptr.as_ptr() }
    }

    /// Borrows a new shared reference to the target of the captured borrow.
    ///
    /// # Safety
    ///
    /// The reference returned by `new`, every reference obtained from
    /// [`reborrow`](Self::reborrow), and everything derived from them must not
    /// be used anymore. No mutable reference may be obtained from `self` while
    /// the returned reference is still in use.
    pub unsafe fn reborrow_shared(&self) -> &'a T {
        // SAFETY: the caller's contract guarantees no live mutable alias.
        unsafe { &*self.ptr.as_ptr() }
    }
}
75
/// Mutates through the reborrow returned by `DormantMutRef::new`, then
/// awakens the dormant borrow and mutates through it as well. Both writes
/// must land in the original variable.
#[test]
fn sanity_check() {
    let mut counter = 0;
    let revived = {
        let (live, sleeper) = DormantMutRef::new(&mut counter);
        *live += 1;
        // SAFETY: `live` is not used after this point.
        unsafe { sleeper.awaken() }
    };
    *revived += 1;
    assert_eq!(counter, 2);
}
87
#[cfg(test)]
mod tests {
    use crate::DormantMutRef;

    /// Toy sparse map from `usize` keys to `String`s, used to exercise
    /// `DormantMutRef` through a `HashMap`-style entry API.
    struct Base {
        /// Slot `i` holds the value for key `i`; `None` marks an empty slot.
        buf: Vec<Option<String>>,
        /// Number of `Some` slots in `buf`, maintained by the entry methods.
        len: usize,
    }

    /// View into a single key of `Base`: either it holds a value or it does not.
    enum Entry<'a> {
        Occupied(OccupiedEntry<'a>),
        Vacant(VacantEntry<'a>),
    }
    use Entry::{Occupied, Vacant};

    /// Entry for a key whose slot holds a value.
    struct OccupiedEntry<'a> {
        key: usize,
        /// The stored value, reached through the reborrow from `DormantMutRef::new`.
        handle: &'a mut String,
        /// Dormant borrow of the whole `Base`, awakened to update bookkeeping.
        dormant_base: DormantMutRef<'a, Base>,
    }

    /// Entry for a key whose slot is empty or lies past the end of `buf`.
    struct VacantEntry<'a> {
        key: usize,
        /// The empty slot, or `None` when `key` is past the end of `buf`.
        handle: Option<&'a mut Option<String>>,
        /// Dormant borrow of the whole `Base`, awakened to update bookkeeping.
        dormant_base: DormantMutRef<'a, Base>,
    }

    impl Base {
        pub fn new() -> Self {
            Self { buf: vec![], len: 0 }
        }

        /// Returns the entry for `key`.
        pub fn entry(&mut self, key: usize) -> Entry<'_> {
            // Both halves carry the full lifetime of `self`. This lets a
            // reference into `buf` be stored next to a way back to `Base`.
            let (self_, dormant_self) = DormantMutRef::new(self);
            match self_.buf.get_mut(key) {
                Some(Some(v)) => Occupied(OccupiedEntry {
                    key,
                    handle: v,
                    dormant_base: dormant_self,
                }),
                Some(v) => Vacant(VacantEntry {
                    key,
                    handle: Some(v),
                    dormant_base: dormant_self,
                }),
                None => Vacant(VacantEntry {
                    key,
                    handle: None,
                    dormant_base: dormant_self,
                }),
            }
        }
    }

    impl<'a> Entry<'a> {
        /// Applies `f` to the value if the entry is occupied; returns the entry.
        pub fn and_modify<F>(self, f: F) -> Self
        where
            F: FnOnce(&mut String),
        {
            match self {
                Occupied(mut entry) => {
                    f(entry.get_mut());
                    Occupied(entry)
                }
                Vacant(entry) => Vacant(entry),
            }
        }

        /// The key this entry refers to.
        pub fn key(&self) -> usize {
            match *self {
                Occupied(ref entry) => entry.key(),
                Vacant(ref entry) => entry.key(),
            }
        }

        /// Returns the value, inserting `String::default()` if vacant.
        pub fn or_default(self) -> &'a mut String {
            match self {
                Occupied(entry) => entry.into_mut(),
                Vacant(entry) => entry.insert(Default::default()),
            }
        }

        /// Returns the value, inserting `default` if vacant.
        pub fn or_insert(self, default: String) -> &'a mut String {
            match self {
                Occupied(entry) => entry.into_mut(),
                Vacant(entry) => entry.insert(default),
            }
        }

        /// Returns the value, inserting `default()` if vacant.
        pub fn or_insert_with<F: FnOnce() -> String>(
            self,
            default: F,
        ) -> &'a mut String {
            match self {
                Occupied(entry) => entry.into_mut(),
                Vacant(entry) => entry.insert(default()),
            }
        }

        /// Returns the value, inserting `default(key)` if vacant.
        pub fn or_insert_with_key<F: FnOnce(usize) -> String>(
            self,
            default: F,
        ) -> &'a mut String {
            match self {
                Occupied(entry) => entry.into_mut(),
                Vacant(entry) => {
                    let value = default(entry.key());
                    entry.insert(value)
                }
            }
        }
    }

    impl<'a> OccupiedEntry<'a> {
        pub fn get(&self) -> &String {
            &*self.handle
        }

        pub fn get_mut(&mut self) -> &mut String {
            self.handle
        }

        /// Replaces the value, returning the old one.
        pub fn insert(&mut self, value: String) -> String {
            std::mem::replace(self.handle, value)
        }

        pub fn into_mut(self) -> &'a mut String {
            self.handle
        }

        pub fn key(&self) -> usize {
            self.key
        }

        /// Removes the entry and returns its value.
        pub fn remove(self) -> String {
            self.remove_entry().1
        }

        /// Removes the entry and returns its key and value.
        ///
        /// The slot is reset to `None`, not merely emptied, so the key reads
        /// as vacant afterwards and `len` stays in sync with `buf`.
        pub fn remove_entry(self) -> (usize, String) {
            // Destructure so `handle` (derived from the reborrow) is dropped
            // and never used once the dormant borrow is awakened.
            let OccupiedEntry { key, handle: _, dormant_base } = self;
            // SAFETY: `handle` was discarded above; nothing derived from the
            // reborrow is used past this point.
            let base = unsafe { dormant_base.awaken() };
            let value = base.buf[key]
                .take()
                .expect("occupied entry's slot holds a value");
            base.len -= 1;
            (key, value)
        }
    }

    impl<'a> VacantEntry<'a> {
        pub fn key(&self) -> usize {
            self.key
        }

        /// Stores `value` under this entry's key and returns a reference to it.
        pub fn insert(self, value: String) -> &'a mut String {
            let VacantEntry { key, handle, dormant_base } = self;
            match handle {
                None => {
                    // SAFETY: there is no handle, so nothing derived from the
                    // reborrow is still in use.
                    let base = unsafe { dormant_base.awaken() };
                    // `key` is past the end of `buf`: pad with empty slots up
                    // to `key`, then append the value at index `key`.
                    base.buf.resize_with(key, || None);
                    base.buf.push(Some(value));
                    base.len += 1;
                    base.buf.last_mut().unwrap().as_mut().unwrap()
                }
                Some(slot) => {
                    let was_none = slot.is_none();
                    *slot = Some(value);
                    // SAFETY: `slot` was the last use of the reborrow and is
                    // not touched again below.
                    let base = unsafe { dormant_base.awaken() };
                    if was_none {
                        base.len += 1;
                    }
                    base.buf[key].as_mut().expect("slot was just filled")
                }
            }
        }
    }

    #[test]
    fn entry() {
        let mut base = Base::new();

        assert_eq!(base.entry(0).key(), 0);

        base.entry(0).or_insert("zero".to_owned());
        assert_eq!(base.buf[0].as_ref().unwrap(), "zero");
        assert_eq!(base.len, 1);

        base.entry(0).or_insert_with(|| "xxx".to_owned());
        assert_eq!(base.buf[0].as_ref().unwrap(), "zero");
        assert_eq!(base.len, 1);

        base.entry(2).or_insert_with_key(|_| "two".to_owned());
        assert!(base.buf[1].is_none());
        assert_eq!(base.buf[2].as_ref().unwrap(), "two");
        assert_eq!(base.len, 2);

        base.entry(2).and_modify(|v| *v = "second".to_owned());
        assert_eq!(base.len, 2);

        match base.entry(2) {
            Occupied(o) => {
                assert_eq!(o.get(), "second");
                assert_eq!(o.remove(), "second");
            }
            Vacant(_) => panic!("entry 2 should be occupied"),
        }
        assert_eq!(base.len, 1);
        // Removal must clear the slot so the key reads as vacant again.
        assert!(base.buf[2].is_none());
        assert!(matches!(base.entry(2), Vacant(_)));

        base.entry(1).or_default();
        assert_eq!(base.len, 2);
        assert!(base.buf[1].as_ref().unwrap().is_empty());
        match base.entry(1) {
            Occupied(mut o) => {
                assert_eq!(o.insert("first".to_owned()), "");
                assert_eq!(o.get(), "first");
            }
            Vacant(_) => panic!("entry 1 should be occupied"),
        }

        // Re-inserting a previously removed key is counted again.
        base.entry(2).or_insert("two again".to_owned());
        assert_eq!(base.len, 3);
        assert_eq!(base.buf[2].as_deref(), Some("two again"));
    }
}