use io::statistics::Instances;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
/// Length of a code word, in bits.
///
/// A thin newtype over `u8`. The derive list gives it conversions to and from
/// the wrapped `u8` (`From`/`Into`), addition and subtraction with other
/// `BitLen` values, and a total ordering, so bit lengths can be used directly
/// as sort keys.
#[derive(
Debug,
Default,
Display,
Serialize,
Deserialize,
From,
Into,
Add,
AddAssign,
Sub,
SubAssign,
Clone,
Copy,
PartialOrd,
Ord,
PartialEq,
Eq,
)]
pub struct BitLen(u8);
/// Lets a `u32` be shifted left by a [`BitLen`], i.e. `bits << bit_len`.
impl std::ops::Shl<BitLen> for u32 {
    type Output = u32;

    /// Shifts `self` left by the number of bits wrapped in `rhs`.
    fn shl(self, rhs: BitLen) -> u32 {
        // Unwrap the newtype first, then defer to the built-in `u32 << u8`.
        let shift: u8 = rhs.into();
        self << shift
    }
}
// NOTE(review): presumably the largest code length (in bits) the format allows.
// Nothing in this part of the file references it -- confirm where it is used.
const MAX_CODE_BIT_LENGTH: u8 = 20;
/// A single code word: a bit pattern together with its length.
#[derive(Debug)]
struct Key {
/// The bit pattern of the code word, stored as an integer.
bits: u32,
/// How many bits the code word occupies.
bit_len: BitLen,
}
/// A node of the coding tree, weighted by how often its symbols occur.
struct Node<T> {
/// Weight of the node: the count of a leaf's symbol, or the sum of the
/// weights of the two merged children for an internal node.
instances: Instances,
/// The leaf symbol or the pair of subtrees below this node.
content: NodeContent<T>,
}
/// Payload of a tree node, without its weight.
enum NodeContent<T> {
/// A leaf holding one symbol.
Leaf(T),
/// An internal node with exactly two subtrees.
Internal {
left: Box<NodeContent<T>>,
right: Box<NodeContent<T>>,
},
}
/// Nodes are compared by weight only; `content` plays no part in the ordering.
impl<T> PartialOrd for Node<T> {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.instances, &rhs.instances)
    }
}
/// Total order on nodes by weight only, so they can live in a `BinaryHeap`.
impl<T> Ord for Node<T> {
    fn cmp(&self, rhs: &Self) -> Ordering {
        Ord::cmp(&self.instances, &rhs.instances)
    }
}
/// Two nodes are equal when their weights are equal, whatever their content.
impl<T> PartialEq for Node<T> {
    fn eq(&self, rhs: &Self) -> bool {
        self.instances == rhs.instances
    }
}
// Marker impl accompanying `PartialEq`; `Ord` requires `Eq` on the same type.
impl<T> Eq for Node<T> {}
/// A table associating each symbol with the code word assigned to it.
#[derive(Debug)]
pub struct Keys<T>
where
T: Ord + Clone,
{
/// `(symbol, code word)` pairs, in the order produced by
/// `Keys::from_instances`: by increasing bit length, then by symbol.
keys: Vec<(T, Key)>,
}
impl<T> Keys<T>
where
T: Ord + Clone,
{
pub fn from_sequence<S>(source: S, max_bit_len: u8) -> Result<Self, u8>
where
S: IntoIterator<Item = T>,
T: PartialEq + Hash,
{
let mut map = HashMap::new();
for item in source {
let counter = map.entry(item).or_insert(0.into());
*counter += 1.into();
}
Self::from_instances(map, max_bit_len)
}
pub fn from_instances<S>(source: S, max_bit_len: u8) -> Result<Self, u8>
where
S: IntoIterator<Item = (T, Instances)>,
{
let mut bit_lengths = Self::compute_bit_lengths(source, max_bit_len)?;
bit_lengths.sort_unstable_by_key(|&(ref value, ref bit_len)| (*bit_len, value.clone()));
let mut bits = 0;
let mut keys = Vec::with_capacity(bit_lengths.len());
for i in 0..bit_lengths.len() - 1 {
let (bit_len, symbol, next_bit_len) = (
bit_lengths[i].1,
bit_lengths[i].0.clone(),
bit_lengths[i + 1].1,
);
keys.push((symbol.clone(), Key { bits, bit_len }));
bits = (bits + 1) << (next_bit_len - bit_len);
}
let (ref symbol, bit_len) = bit_lengths[bit_lengths.len() - 1];
keys.push((symbol.clone(), Key { bits, bit_len }));
return Ok(Self { keys });
}
pub fn compute_bit_lengths<S>(source: S, max_bit_len: u8) -> Result<Vec<(T, BitLen)>, u8>
where
S: IntoIterator<Item = (T, Instances)>,
{
use std::cmp::Reverse;
let mut heap = BinaryHeap::new();
for (value, instances) in source {
if !instances.is_zero() {
heap.push(Reverse(Node {
instances,
content: NodeContent::Leaf(value),
}));
}
}
let len = heap.len();
if len == 0 {
return Ok(vec![]);
}
while heap.len() > 1 {
let left = heap.pop().unwrap();
let right = heap.pop().unwrap();
heap.push(Reverse(Node {
instances: left.0.instances + right.0.instances,
content: NodeContent::Internal {
left: Box::new(left.0.content),
right: Box::new(right.0.content),
},
}));
}
let root = heap.pop().unwrap();
let mut bit_lengths = Vec::with_capacity(len);
fn aux<T>(
bit_lengths: &mut Vec<(T, BitLen)>,
max_bit_len: u8,
depth: u8,
node: &NodeContent<T>,
) -> Result<(), u8>
where
T: Clone,
{
match *node {
NodeContent::Leaf(ref value) => {
if depth > max_bit_len {
return Err(depth);
}
bit_lengths.push((value.clone(), BitLen(depth)));
Ok(())
}
NodeContent::Internal {
ref left,
ref right,
} => {
aux(bit_lengths, max_bit_len, depth + 1, left)?;
aux(bit_lengths, max_bit_len, depth + 1, right)?;
Ok(())
}
}
}
aux(&mut bit_lengths, max_bit_len, 0, &root.0.content)?;
Ok(bit_lengths)
}
}
/// Checks the code table built for "appl": 'p' occurs twice and gets the
/// 1-bit code, 'a' and 'l' occur once each and get the 2-bit codes.
#[test]
fn test_coded_from_sequence() {
    let text = "appl";
    let table = Keys::from_sequence(text.chars(), std::u8::MAX).unwrap();

    // (symbol, expected bit length, expected code bits), in table order.
    let expected = [('p', 1u8, 0b00u32), ('a', 2u8, 0b10u32), ('l', 2u8, 0b11u32)];
    assert_eq!(table.keys.len(), 3);
    for (&(ref symbol, ref key), &(want_symbol, want_len, want_bits)) in
        table.keys.iter().zip(expected.iter())
    {
        assert_eq!(*symbol, want_symbol);
        assert_eq!(key.bit_len, want_len.into());
        assert_eq!(key.bits, want_bits);
    }

    // With at most one bit per code, the 2-bit leaves must be rejected.
    assert_eq!(Keys::from_sequence(text.chars(), 1).unwrap_err(), 2);
}