// core/slice/rotate.rs
use crate::mem::{self, MaybeUninit, SizedTypeProperties};
use crate::{cmp, ptr};

/// Layout of the stack scratch buffer used by `ptr_rotate_memmove`: 32 `usize`s.
///
/// `ptr_rotate` only selects the buffered strategy when `min(left, right)` elements of `T` fit
/// into `size_of::<BufType>()` bytes.
type BufType = [usize; 32];

/// Rotates the elements in `[mid-left, mid+right)` so that the element at `mid` ends up first.
/// Put differently, the range is rotated `left` positions to the left, which is the same as
/// rotating it `right` positions to the right.
///
/// This function is only a dispatcher: it returns immediately for zero-sized `T` or when either
/// side is empty, and otherwise hands the work to one of the three rotation strategies depending
/// on the sizes involved.
///
/// # Safety
///
/// The specified range must be valid for reading and writing.
#[inline]
pub(super) unsafe fn ptr_rotate<T>(left: usize, mid: *mut T, right: usize) {
    // Nothing to do for zero-sized elements, or when one side of the split is empty.
    if T::IS_ZST || left == 0 || right == 0 {
        return;
    }
    if cfg!(feature = "optimize_for_size") {
        // Size-optimized builds always take the swapping strategy.
        // SAFETY: guaranteed by the caller
        unsafe { ptr_rotate_swap(left, mid, right) };
        return;
    }
    // `T` is not zero-sized past this point, so dividing by its size is fine.
    let elem_size = mem::size_of::<T>();
    let buf_capacity = mem::size_of::<BufType>() / elem_size;
    if cmp::min(left, right) <= buf_capacity {
        // The shorter side fits in the scratch buffer.
        // SAFETY: guaranteed by the caller
        unsafe { ptr_rotate_memmove(left, mid, right) };
    } else if left + right < 24 || elem_size > mem::size_of::<[usize; 4]>() {
        // Short ranges and large elements go through the cycle-following strategy.
        // SAFETY: guaranteed by the caller
        unsafe { ptr_rotate_gcd(left, mid, right) }
    } else {
        // SAFETY: guaranteed by the caller
        unsafe { ptr_rotate_swap(left, mid, right) }
    }
}

/// Algorithm 1: rotation through a stack buffer, used when `min(left, right)` elements fit in
/// `BufType`.
///
/// The shorter side is stashed in the buffer, the longer side is shifted over it with a
/// `memmove`-style copy, and the stashed elements are then written into the gap that opened up
/// at the other end of the range.
///
/// # Safety
///
/// The specified range must be valid for reading and writing.
#[inline]
unsafe fn ptr_rotate_memmove<T>(left: usize, mid: *mut T, right: usize) {
    // Pairing the buffer with a `[T; 0]` gives the scratch space the alignment of `T`.
    let mut scratch = MaybeUninit::<(BufType, [T; 0])>::uninit();
    let stash: *mut T = scratch.as_mut_ptr().cast();
    // SAFETY: `mid-left` is the first element of the range the caller vouches for.
    let start = unsafe { mid.sub(left) };
    // SAFETY: `mid-left <= mid-left+right < mid+right`
    let split = unsafe { start.add(right) };

    // Decide which side is parked in the buffer (the shorter one) and where the longer side
    // has to slide to:
    // - `left <= right`: park `[start, mid)`, slide the right side down to `start`, and put
    //   the parked elements at `split`, where `split + left == mid + right`.
    // - otherwise: park `[mid, mid+right)`, slide the left side up to `split`, and put the
    //   parked elements at `start`.
    let (stash_src, stash_len, bulk_src, bulk_dst, bulk_len, restore_dst) = if left <= right {
        (start, left, mid, start, right, split)
    } else {
        (mid, right, start, split, left, start)
    };

    // SAFETY:
    //
    // 1) The caller only uses this algorithm when the shorter side fits into `stash`, and
    //    `stash` is a fresh local, so it cannot overlap any part of the range.
    // 2) Both the source and the destination of the bulk move lie inside
    //    `[mid-left, mid+right)`, which is valid for reading and writing; `ptr::copy` permits
    //    the overlap.
    // 3) `stash` holds the `stash_len` elements written in 1), and
    //    `[restore_dst, restore_dst + stash_len)` is exactly the part of the range vacated by
    //    the bulk move.
    unsafe {
        // 1)
        ptr::copy_nonoverlapping(stash_src, stash, stash_len);
        // 2)
        ptr::copy(bulk_src, bulk_dst, bulk_len);
        // 3)
        ptr::copy_nonoverlapping(stash, restore_dst, stash_len);
    }
}

/// Algorithm 2 is used for small values of `left + right` or for large `T`. The elements
/// are moved into their final positions one at a time starting at `mid - left` and advancing by
/// `right` steps modulo `left + right`, such that only one temporary is needed. Eventually, we
/// arrive back at `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps
/// skipped over elements. For example:
/// ```text
/// left = 10, right = 6
/// the `^` indicates an element in its final place
/// 6 7 8 9 10 11 12 13 14 15 . 0 1 2 3 4 5
/// after using one step of the above algorithm (The X will be overwritten at the end of the round,
/// and 12 is stored in a temporary):
/// X 7 8 9 10 11 6 13 14 15 . 0 1 2 3 4 5
///               ^
/// after using another step (now 2 is in the temporary):
/// X 7 8 9 10 11 6 13 14 15 . 0 1 12 3 4 5
///               ^                 ^
/// after the third step (the steps wrap around, and 8 is in the temporary):
/// X 7 2 9 10 11 6 13 14 15 . 0 1 12 3 4 5
///     ^         ^                 ^
/// after 7 more steps, the round ends with the temporary 0 getting put in the X:
/// 0 7 2 9 4 11 6 13 8 15 . 10 1 12 3 14 5
/// ^   ^   ^    ^    ^       ^    ^    ^
/// ```
/// Fortunately, the number of skipped over elements between finalized elements is always equal, so
/// we can just offset our starting position and do more rounds (the total number of rounds is the
/// `gcd(left + right, right)` value). The end result is that all elements are finalized once and
/// only once.
///
/// Algorithm 2 can be vectorized by chunking and performing many rounds at once, but there are too
/// few rounds on average until `left + right` is enormous, and the worst case of a single
/// round is always there.
///
/// If either `left` or `right` is zero the rotation is a no-op and the function returns without
/// touching memory.
///
/// # Safety
///
/// The specified range must be valid for reading and writing.
#[inline]
unsafe fn ptr_rotate_gcd<T>(left: usize, mid: *mut T, right: usize) {
    // An empty side means every element is already in its final place. Without this check the
    // loop below would either never reach its exit condition (`right == 0`) or write to
    // `mid + right`, one past the end of the range (`left == 0`).
    if left == 0 || right == 0 {
        return;
    }
    // Algorithm 2
    // Microbenchmarks indicate that the average performance for random shifts is better all
    // the way until about `left + right == 32`, but the worst case performance breaks even
    // around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
    // `usize`s, this algorithm also outperforms other algorithms.
    // SAFETY: callers must ensure `mid - left` is valid for reading and writing.
    let x = unsafe { mid.sub(left) };
    // beginning of first round
    // SAFETY: see previous comment.
    let mut tmp: T = unsafe { x.read() };
    let mut i = right;
    // `gcd` can be found before hand by calculating `gcd(left + right, right)`,
    // but it is faster to do one loop which calculates the gcd as a side effect, then
    // doing the rest of the chunk
    let mut gcd = right;
    // benchmarks reveal that it is faster to swap temporaries all the way through instead
    // of reading one temporary once, copying backwards, and then writing that temporary at
    // the very end. This is possibly due to the fact that swapping or replacing temporaries
    // uses only one memory address in the loop instead of needing to manage two.
    loop {
        // [long-safety-expl]
        // SAFETY: callers must ensure `[mid-left, mid+right)` are all valid for reading and
        // writing.
        //
        // - `i` start with `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right`
        //   (the last inequality holds because `left > 0`)
        // - `i <= left+right-1` is always true
        //   - if `i < left`, `right` is added so `i < left+right` and on the next
        //     iteration `left` is removed from `i` so it doesn't go further
        //   - if `i >= left`, `left` is removed immediately and so it doesn't go further.
        // - overflows cannot happen for `i` since the function's safety contract ask for
        //   `mid+right-1 = x+left+right-1` to be valid for writing
        // - underflows cannot happen because `i` must be bigger or equal to `left` for
        //   a subtraction of `left` to happen.
        //
        // So `x+i` is valid for reading and writing if the caller respected the contract
        tmp = unsafe { x.add(i).replace(tmp) };
        // instead of incrementing `i` and then checking if it is outside the bounds, we
        // check if `i` will go outside the bounds on the next increment. This prevents
        // any wrapping of pointers or `usize`.
        if i >= left {
            i -= left;
            if i == 0 {
                // end of first round
                // SAFETY: tmp has been read from a valid source and x is valid for writing
                // according to the caller.
                unsafe { x.write(tmp) };
                break;
            }
            // this conditional must be here if `left + right >= 15`
            if i < gcd {
                gcd = i;
            }
        } else {
            i += right;
        }
    }
    // finish the chunk with more rounds
    for start in 1..gcd {
        // SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for
        // reading and writing as per the function's safety contract, see [long-safety-expl]
        // above
        tmp = unsafe { x.add(start).read() };
        // [safety-expl-addition]
        //
        // After the first round `gcd` holds `gcd(left + right, right)`, which divides both
        // `right` and `left`. Here `start < gcd` and `gcd <= left`, so
        // `i = start + right < left + right` and `x+i = mid-left+i` is always valid for
        // reading and writing according to the function's safety contract.
        i = start + right;
        loop {
            // SAFETY: see [long-safety-expl] and [safety-expl-addition]
            tmp = unsafe { x.add(i).replace(tmp) };
            if i >= left {
                i -= left;
                if i == start {
                    // SAFETY: see [long-safety-expl] and [safety-expl-addition]
                    unsafe { x.add(start).write(tmp) };
                    break;
                }
            } else {
                i += right;
            }
        }
    }
}

/// Algorithm 3 utilizes repeated swapping of `min(left, right)` elements.
///
/// ```text
/// left = 11, right = 4
/// [4 5 6 7 8 9 10 11 12 13 14 . 0 1 2 3]
///                  ^  ^  ^  ^   ^ ^ ^ ^ swapping the right most elements with elements to the left
/// [4 5 6 7 8 9 10 . 0 1 2 3] 11 12 13 14
///        ^ ^ ^  ^   ^ ^ ^ ^ swapping these
/// [4 5 6 . 0 1 2 3] 7 8 9 10 11 12 13 14
/// we cannot swap any more, but a smaller rotation problem is left to solve
/// ```
/// when `left < right` the swapping happens from the left instead.
///
/// The smaller problem is then solved the same way, until one of the two sides is empty. If
/// either `left` or `right` is already zero on entry, the rotation is a no-op and the function
/// returns without touching memory.
///
/// # Safety
///
/// The specified range must be valid for reading and writing.
#[inline]
unsafe fn ptr_rotate_swap<T>(mut left: usize, mut mid: *mut T, mut right: usize) {
    // Once either side is empty, whatever remains is already in place. Checking this before
    // each pass (rather than only after one) also keeps an empty side from spinning forever on
    // zero-length swaps.
    while left != 0 && right != 0 {
        if left >= right {
            // Algorithm 3
            // There is an alternate way of swapping that involves finding where the last swap
            // of this algorithm would be, and swapping using that last chunk instead of swapping
            // adjacent chunks like this algorithm is doing, but this way is still faster.
            while left >= right {
                // SAFETY:
                // `left >= right` so `[mid-right, mid+right)` is valid for reading and writing.
                // Moving `mid` back by `right` is matched by shrinking `left` by the same
                // amount, so `mid - left` keeps pointing at the start of the range.
                unsafe {
                    ptr::swap_nonoverlapping(mid.sub(right), mid, right);
                    mid = mid.sub(right);
                }
                left -= right;
            }
        } else {
            // Algorithm 3, `left < right`
            while right >= left {
                // SAFETY: `[mid-left, mid+left)` is valid for reading and writing because
                // `left <= right` so `mid+left <= mid+right`.
                // Moving `mid` forward by `left` is matched by shrinking `right` by the same
                // amount, so `mid + right` keeps pointing at the end of the range.
                unsafe {
                    ptr::swap_nonoverlapping(mid.sub(left), mid, left);
                    mid = mid.add(left);
                }
                right -= left;
            }
        }
    }
}