1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
//! Implementation of varnums used in Context-0.x.
//!
//! Note that this is NOT the same implementation as used in other experiments.
//!
//! # Format
//!
//! ```bnf
//! varu32 ::= [0b1nnnnnnn]{0..4} 0b0nnnnnnn
//! ```
//!
//! The 7-bit groups are stored little-endian, i.e. least-significant group first.
//! A most-significant bit of `1` in a byte indicates a non-last byte.
//! A most-significant bit of `0` in a byte indicates a last byte.

use std::io::{self, Error, Read, Result, Write};

// --------------- Reading

/// A result of reading data from a byte-oriented stream.
#[derive(Debug, Clone)]
pub struct ByteValue<T> {
    /// The value read.
    pub value: T,

    /// The number of bytes consumed from the stream to produce `value`.
    pub byte_len: usize,
}

/// A reader that may read varu32-encoded u32 values from a stream.
pub trait ReadVaru32 {
    /// Read a single varu32.
    ///
    /// Note that this operation may return denormalized 0 values, e.g.
    /// `ByteValue { value: 0, byte_len: 5 }`. Such values may be used
    /// to represent exceptional cases.
    ///
    /// # Errors
    ///
    /// - Any error reported by the underlying reader, e.g. `UnexpectedEof`
    ///   when the stream ends before a terminating byte has been read.
    /// - `InvalidData` if the encoded value does not fit in 32 bits or if
    ///   the fifth byte still has its continuation bit set.
    fn read_varu32_no_normalization(&mut self) -> Result<ByteValue<u32>>;
}

impl<T> ReadVaru32 for T
where
    T: Read,
{
    fn read_varu32_no_normalization(&mut self) -> Result<ByteValue<u32>> {
        let mut result: u32 = 0;

        // Each byte carries 7 payload bits, so a u32 spans at most 5 bytes,
        // whose payloads land at bit offsets 0, 7, 14, 21 and 28.
        for (index, shift) in (0u32..32).step_by(7).enumerate() {
            let mut bytes = [0u8];
            self.read_exact(&mut bytes)?;
            let byte = bytes[0];

            let payload = u32::from(byte & 0b0111_1111);
            let shifted = payload << shift;
            // `<<` silently drops bits pushed past bit 31 (only possible for
            // the fifth byte); shifting back reveals whether any were lost.
            if shifted >> shift != payload {
                return Err(Error::new(
                    io::ErrorKind::InvalidData,
                    "Overflow during read_varu32_no_normalization",
                ));
            }
            result |= shifted;

            // A cleared high bit marks the last byte of the number.
            if byte & 0b1000_0000 == 0 {
                return Ok(ByteValue {
                    value: result,
                    byte_len: index + 1,
                });
            }
        }

        // Five bytes were read and the last one still asked for more.
        Err(Error::new(
            io::ErrorKind::InvalidData,
            "Overflow during read_varu32_no_normalization (too many digits)",
        ))
    }
}

#[test]
fn test_read_varu32_no_normalization_must_work() {
    use std::io::Cursor;

    // Each entry pairs an encoded byte sequence with the value it decodes to.
    let cases: Vec<(Vec<u8>, u32)> = vec![
        // Various versions of 0 (including denormalized ones)
        (vec![0], 0),
        (vec![0b10000000, 0b00000000], 0),
        (vec![0b10000000, 0b10000000, 0b00000000], 0),
        // Various numbers in [0, 128)
        (vec![0b00000001], 1),
        (vec![0b00000010], 2),
        (vec![0b00000100], 4),
        (vec![0b00001000], 8),
        (vec![0b00010000], 16),
        (vec![0b00100000], 32),
        (vec![0b01000000], 64),
        (vec![0b01111111], 127),
        // Various numbers in [128, 16384)
        (vec![0b10000000, 0b00000001], 128),
        (vec![0b10000000, 0b00000010], 256),
        (vec![0b10000000, 0b00000100], 512),
        (vec![0b10000000, 0b00001000], 1024),
        (vec![0b10000000, 0b00010000], 2048),
        (vec![0b10000000, 0b00100000], 4096),
        (vec![0b10000000, 0b01000000], 8192),
        (vec![0b10000001, 0b00000001], 128 + 1),
        (vec![0b10000010, 0b00000001], 128 + 2),
        (vec![0b10000100, 0b00000001], 128 + 4),
        (vec![0b10001000, 0b00000001], 128 + 8),
        (vec![0b10010000, 0b00000001], 128 + 16),
        (vec![0b10100000, 0b00000001], 128 + 32),
        (vec![0b11000000, 0b00000001], 128 + 64),
        (vec![0b11111111, 0b01111111], 0b00111111_11111111),
        (vec![0b11111111, 0b00000000], 0b00000000_01111111),
        (vec![0b10000000, 0b01111111], 0b00111111_10000000),
    ];

    for (encoded, expected) in &cases {
        // Decode starting from the first byte of the sequence.
        let decoded = match Cursor::new(encoded).read_varu32_no_normalization() {
            Ok(decoded) => decoded,
            Err(e) => panic!("Could not read from {:?}: {:?}", encoded, e),
        };

        // The decoded number must match...
        assert_eq!(decoded.value, *expected);

        // ...and every byte of the sequence must have been consumed.
        assert_eq!(decoded.byte_len, encoded.len());
    }
}

/// Assert that decoding a varu32 from the start of `buf` fails with an
/// error of kind `kind`.
///
/// Panics if decoding succeeds or if it fails with any other kind of error.
#[cfg(test)]
fn expect_error(buf: &[u8], kind: io::ErrorKind) {
    match io::Cursor::new(buf).read_varu32_no_normalization() {
        Ok(_) => panic!("Expected an error"),
        Err(ref e) if e.kind() == kind => { /* good */ }
        Err(ref e) => panic!(
            "Expected {expected:?}, got error {e:?}",
            expected = kind,
            e = e
        ),
    }
}

#[test]
fn test_read_varu32_no_normalization_must_fail() {
    // Truncated inputs: zero to four bytes that all have the continuation
    // bit set, so the stream ends before a terminating byte shows up.
    for len in 0..5 {
        let truncated = vec![0b10000000u8; len];
        expect_error(&truncated, io::ErrorKind::UnexpectedEof);
    }

    // A value made up from too many bytes: eight continuation bytes
    // followed by a terminating byte.
    let mut overlong = vec![0b10000000u8; 8];
    overlong.push(0);
    expect_error(&overlong, io::ErrorKind::InvalidData);
}

// ---------- Writing

/// A writer that may write varu32-encoded u32 values into a stream.
pub trait WriteVaru32 {
    /// Write a single varu32.
    ///
    /// Return the number of bytes written to the stream.
    ///
    /// # Errors
    ///
    /// Any error reported by the underlying writer.
    fn write_varu32(&mut self, value: u32) -> Result<usize>;
}

impl<T> WriteVaru32 for T
where
    T: Write,
{
    fn write_varu32(&mut self, mut value: u32) -> Result<usize> {
        // 32 bits at 7 payload bits per byte need at most 5 bytes. Assemble
        // them locally so that the writer is hit once, not once per byte.
        let mut encoded = [0u8; 5];
        let mut len = 0;

        // Emit the low 7 bits with the continuation bit set for as long as
        // higher bits remain.
        while value > 0b0111_1111 {
            encoded[len] = 0b1000_0000 | (value as u8 & 0b0111_1111);
            len += 1;
            value >>= 7;
        }

        // What is left fits in 7 bits: this is the terminating byte.
        encoded[len] = value as u8;
        len += 1;

        self.write_all(&encoded[..len])?;
        Ok(len)
    }
}

#[test]
fn test_write_varu32_must_work() {
    // Each entry pairs the expected encoding with the value to encode.
    let cases: Vec<(Vec<u8>, u32)> = vec![
        // Various numbers in [0, 128)
        (vec![0], 0),
        (vec![0b00000001], 1),
        (vec![0b00000010], 2),
        (vec![0b00000100], 4),
        (vec![0b00001000], 8),
        (vec![0b00010000], 16),
        (vec![0b00100000], 32),
        (vec![0b01000000], 64),
        (vec![0b01111111], 127),
        // Various numbers in [128, 16384)
        (vec![0b10000000, 0b00000001], 128),
        (vec![0b10000000, 0b00000010], 256),
        (vec![0b10000000, 0b00000100], 512),
        (vec![0b10000000, 0b00001000], 1024),
        (vec![0b10000000, 0b00010000], 2048),
        (vec![0b10000000, 0b00100000], 4096),
        (vec![0b10000000, 0b01000000], 8192),
        (vec![0b10000001, 0b00000001], 128 + 1),
        (vec![0b10000010, 0b00000001], 128 + 2),
        (vec![0b10000100, 0b00000001], 128 + 4),
        (vec![0b10001000, 0b00000001], 128 + 8),
        (vec![0b10010000, 0b00000001], 128 + 16),
        (vec![0b10100000, 0b00000001], 128 + 32),
        (vec![0b11000000, 0b00000001], 128 + 64),
        (vec![0b11111111, 0b01111111], 0b00111111_11111111),
        (vec![0b10000000, 0b01111111], 0b00111111_10000000),
        // Numbers that need 3, 4 and 5 bytes
        (vec![0b10000000, 0b10000000, 0b00000001], 1 << 14),
        (vec![0b10000000, 0b10000000, 0b10000000, 0b00000001], 1 << 21),
        (
            vec![0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b00000001],
            1 << 28,
        ),
        (
            vec![0b11111111, 0b11111111, 0b11111111, 0b11111111, 0b00001111],
            0xFFFF_FFFF,
        ),
    ];

    for (expected, input) in &cases {
        let mut buf: Vec<u8> = Vec::new();
        let byte_len = buf.write_varu32(*input).unwrap();

        // The reported length must match what actually reached the buffer.
        assert_eq!(buf.len(), byte_len);
        assert_eq!(byte_len, expected.len());
        assert_eq!(expected, &buf);
    }
}

#[test]
fn test_write_read_varu32() {
    /// Encode `expected`, decode it again and check that both the value and
    /// the byte counts agree.
    fn assert_roundtrip(expected: u32) {
        let mut buf: Vec<u8> = Vec::new();
        let bytes_written = buf.write_varu32(expected).unwrap();
        assert_eq!(bytes_written, buf.len());

        let byte_value = io::Cursor::new(&buf)
            .read_varu32_no_normalization()
            .unwrap();
        assert_eq!(byte_value.byte_len, bytes_written);
        assert_eq!(byte_value.value, expected);
    }

    // Test a bunch of arbitrary values.
    for i in 0..256u32 {
        assert_roundtrip(i * i);
        assert_roundtrip(i * i * i);
    }

    // Test both sides of every encoded-length boundary, up to the largest
    // u32 (the arbitrary values above never reach a 5-byte encoding).
    for &boundary in &[
        0u32,
        0x7F,
        0x80,
        0x3FFF,
        0x4000,
        0x1F_FFFF,
        0x20_0000,
        0xFFF_FFFF,
        0x1000_0000,
        0xFFFF_FFFF,
    ] {
        assert_roundtrip(boundary);
    }
}