use entropy::probabilities::{InstancesToProbabilities, SymbolIndex, SymbolInfo};
pub use io::statistics::Instances;
use binjs_shared::{IOPath, IOPathItem};
use std;
use std::borrow::Borrow;
use std::cell::RefCell;
use std::collections::HashMap;
use std::hash::Hash;
use std::rc::Rc;
#[allow(unused_imports)] 
use itertools::Itertools;
use range_encoding;
/// Position of a value in `WindowPredict`'s dictionary of every distinct value
/// seen so far (`WindowPredict::value_by_dictionary_index`).
///
/// Its derived `Ord` takes part in the derived `Ord` of `WindowPrediction`,
/// which decides the order in which window symbols are numbered.
#[derive(
    Add,
    Constructor,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Clone,
    Copy,
    From,
    Into,
    Debug,
    Hash,
    Serialize,
    Deserialize,
)]
struct DictionaryIndex(usize);
/// Position of a value in `WindowPredict`'s window of recently used values
/// (`WindowPredict::latest_values`), where 0 is the most recently used value.
#[derive(
    Constructor,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Clone,
    Copy,
    Hash,
    Into,
    Debug,
    Serialize,
    Deserialize,
)]
struct BackReference(usize);
mod context_information {
    use super::Instances;
    use entropy::probabilities::{SymbolIndex, SymbolInfo};
    use std::cell::RefCell;
    use std::collections::HashMap;
    use std::hash::Hash;
    use std::rc::Rc;
    use itertools::Itertools;

    /// Statistics about the node values observed in a single context.
    ///
    /// The impls below use `Statistics = Instances` while samples are being
    /// collected, and `Statistics = SymbolInfo` once
    /// `instances_to_probabilities` has turned the counts into a cumulative
    /// distribution.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct ContextInformation<NodeValue, Statistics>
    where
        NodeValue: Eq + Hash,
    {
        /// Statistics for every node value seen in this context.
        stats_by_node_value: HashMap<NodeValue, Statistics>,

        /// Reverse mapping from symbol index to node value.
        ///
        /// Only filled by `instances_to_probabilities`; tables created with
        /// `new` leave it empty.
        value_by_symbol_index: Vec<NodeValue>,
    }

    impl<NodeValue, Statistics> ContextInformation<NodeValue, Statistics>
    where
        NodeValue: Eq + Hash,
    {
        /// Creates an empty table.
        pub fn new() -> Self {
            Self {
                stats_by_node_value: HashMap::new(),
                value_by_symbol_index: Vec::new(),
            }
        }

        /// Returns the number of distinct node values in this context.
        pub fn len(&self) -> usize {
            self.stats_by_node_value.len()
        }

        /// Consumes the table, yielding each node value with its statistics,
        /// in unspecified order.
        pub fn into_iter(self) -> impl Iterator<Item = (NodeValue, Statistics)> {
            let ContextInformation {
                stats_by_node_value,
                ..
            } = self;
            stats_by_node_value.into_iter()
        }

        /// Iterates over node values and their statistics, in unspecified order.
        pub fn iter(&self) -> impl Iterator<Item = (&NodeValue, &Statistics)> {
            self.stats_by_node_value.iter()
        }
    }

    impl<NodeValue> ContextInformation<NodeValue, SymbolInfo>
    where
        NodeValue: Eq + Hash,
    {
        /// Read-only access to the per-value symbol information.
        pub fn stats_by_node_value(&self) -> &HashMap<NodeValue, SymbolInfo> {
            &self.stats_by_node_value
        }

        /// Mutable access to the per-value symbol information.
        pub fn stats_by_node_value_mut(&mut self) -> &mut HashMap<NodeValue, SymbolInfo> {
            &mut self.stats_by_node_value
        }

        /// Returns the node value encoded by `index`, or `None` if `index` is
        /// out of range.
        pub fn value_by_symbol_index(&self, index: SymbolIndex) -> Option<&NodeValue> {
            let position: usize = index.into();
            self.value_by_symbol_index.get(position)
        }
    }

    impl<NodeValue> ContextInformation<NodeValue, Instances>
    where
        NodeValue: Eq + Hash,
    {
        /// Records one more occurrence of `node_value`. A new value starts
        /// at a count of 1.
        pub fn add(&mut self, node_value: NodeValue) {
            self.stats_by_node_value
                .entry(node_value)
                .and_modify(|count| *count += 1.into())
                .or_insert(1.into());
        }

        /// Records `node_value` with a count of 1 unless it is already
        /// present. An existing count is left untouched.
        pub fn add_if_absent(&mut self, node_value: NodeValue) {
            self.stats_by_node_value
                .entry(node_value)
                .or_insert(1.into());
        }
    }

    impl<NodeValue> ::entropy::probabilities::InstancesToProbabilities
        for ContextInformation<NodeValue, Instances>
    where
        NodeValue: Clone + Eq + Hash + Ord,
    {
        type AsProbabilities = ContextInformation<NodeValue, SymbolInfo>;

        /// Turns the occurrence counts of this context into a single
        /// cumulative distribution shared by all of its values.
        ///
        /// Symbol indices are assigned in ascending `NodeValue` order, so the
        /// numbering does not depend on `HashMap` iteration order.
        fn instances_to_probabilities(
            &self,
            _description: &str,
        ) -> ContextInformation<NodeValue, SymbolInfo> {
            let sorted: Vec<(&NodeValue, &Instances)> = self
                .stats_by_node_value
                .iter()
                .sorted_by(|(left, _), (right, _)| Ord::cmp(left, right))
                .collect();

            let counts = sorted
                .iter()
                .map(|&(_, count)| Into::<usize>::into(*count) as u32)
                .collect();
            let distribution = Rc::new(RefCell::new(
                range_encoding::CumulativeDistributionFrequency::new(counts),
            ));

            let mut stats_by_node_value = HashMap::new();
            let mut value_by_symbol_index = Vec::with_capacity(sorted.len());
            for (position, (value, _)) in sorted.into_iter().enumerate() {
                let info = SymbolInfo {
                    index: position.into(),
                    distribution: Rc::clone(&distribution),
                };
                stats_by_node_value.insert(value.clone(), info);
                value_by_symbol_index.push(value.clone());
            }

            ContextInformation {
                stats_by_node_value,
                value_by_symbol_index,
            }
        }
    }

    impl<NodeValue> ContextInformation<NodeValue, SymbolInfo>
    where
        NodeValue: Clone + Eq + Hash,
    {
        /// Returns the distribution attached to an arbitrary value of this
        /// context, or `None` if the context is empty.
        pub fn frequencies(
            &self,
        ) -> Option<&Rc<RefCell<range_encoding::CumulativeDistributionFrequency>>> {
            // Tables built by `instances_to_probabilities` share one
            // distribution between all their values, so any entry will do.
            let any_value = self.stats_by_node_value.values().next()?;
            Some(&any_value.distribution)
        }
    }
}
use self::context_information::ContextInformation;
/// Collects statistics on node values, keyed by an arbitrary `Context`.
///
/// The impls in this file use `Statistics = Instances` while samples are
/// being collected, and `Statistics = SymbolInfo` after conversion with
/// `instances_to_probabilities`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ContextPredict<Context, NodeValue, Statistics>
where
    Context: Eq + Hash + Clone,
    NodeValue: Eq + Hash + Clone,
{
    /// Per-context table of node values and their statistics.
    by_context: HashMap<Context, ContextInformation<NodeValue, Statistics>>,
}
impl<Context, NodeValue, Statistics> ContextPredict<Context, NodeValue, Statistics>
where
    Context: Eq + Hash + Clone,
    NodeValue: Eq + Hash + Clone,
{
    pub fn new() -> Self {
        Self {
            by_context: HashMap::new(),
        }
    }
    
    
    
    pub fn contexts(&self) -> impl Iterator<Item = &Context> {
        self.by_context.keys()
    }
    
    pub fn iter_mut(
        &mut self,
    ) -> impl Iterator<Item = (&Context, &mut ContextInformation<NodeValue, Statistics>)> {
        self.by_context.iter_mut()
    }
    pub fn iter(
        &self,
    ) -> impl Iterator<Item = (&Context, &ContextInformation<NodeValue, Statistics>)> {
        self.by_context.iter()
    }
    
    pub fn into_iter(
        self,
    ) -> impl Iterator<Item = (Context, ContextInformation<NodeValue, Statistics>)> {
        self.by_context.into_iter()
    }
    
    pub fn len(&self) -> usize {
        self.by_context.values().map(ContextInformation::len).sum()
    }
    
    pub fn iter_at<C: ?Sized>(&self, context: &C) -> impl Iterator<Item = (&NodeValue, &Statistics)>
    where
        Context: std::borrow::Borrow<C>,
        C: Hash + Eq,
    {
        std::iter::Iterator::flatten(
            self.by_context
                .get(context)
                .into_iter()
                .map(|info| info.iter()),
        )
    }
}
impl<Context, NodeValue> ContextPredict<Context, NodeValue, Instances>
where
    Context: Eq + Hash + Clone,
    NodeValue: Eq + Hash + Clone,
{
    /// Records one occurrence of `value` in `context`, creating the context
    /// if it does not exist yet.
    pub fn add(&mut self, context: Context, value: NodeValue) {
        let table = self
            .by_context
            .entry(context)
            .or_insert_with(ContextInformation::new);
        table.add(value);
    }

    /// Records `value` in `context` with a count of 1 unless it is already
    /// there. The context is created if needed.
    pub fn add_if_absent(&mut self, context: Context, value: NodeValue) {
        let table = self
            .by_context
            .entry(context)
            .or_insert_with(ContextInformation::new);
        table.add_if_absent(value);
    }
}
impl<Context, NodeValue> ContextPredict<Context, NodeValue, SymbolInfo>
where
    Context: Eq + Hash + Clone,
    NodeValue: Eq + Hash + Clone,
{
    /// Decodes `index` into a node value.
    ///
    /// `candidates` are contexts tried in order of preference, and the first
    /// one present in the table is used. Returns `None` if no candidate is
    /// present, or if `index` is out of range for the chosen context. In that
    /// case later candidates are *not* consulted.
    pub fn value_by_symbol_index<C2: ?Sized>(
        &self,
        candidates: &[&C2],
        index: SymbolIndex,
    ) -> Option<&NodeValue>
    where
        Context: std::borrow::Borrow<C2>,
        C2: Hash + Eq,
    {
        self.context_info_at(candidates)?
            .value_by_symbol_index(index)
    }

    /// Returns the distribution of the first candidate context present in
    /// the table. Returns `None` if no candidate is present or that context
    /// has no values.
    pub fn frequencies_at<C2: ?Sized>(
        &self,
        candidates: &[&C2],
    ) -> Option<&Rc<RefCell<range_encoding::CumulativeDistributionFrequency>>>
    where
        Context: std::borrow::Borrow<C2>,
        C2: Hash + Eq,
    {
        self.context_info_at(candidates)
            .and_then(|table| table.frequencies())
    }

    /// Returns the symbol information for `value` in the first candidate
    /// context present in the table.
    pub fn stats_by_node_value<C2: ?Sized>(
        &self,
        candidates: &[&C2],
        value: &NodeValue,
    ) -> Option<&SymbolInfo>
    where
        Context: std::borrow::Borrow<C2>,
        C2: Hash + Eq,
    {
        self.context_info_at(candidates)
            .and_then(|table| table.stats_by_node_value().get(value))
    }

    /// Mutable counterpart of `stats_by_node_value`.
    pub fn stats_by_node_value_mut<C2: ?Sized>(
        &mut self,
        candidates: &[&C2],
        value: &NodeValue,
    ) -> Option<&mut SymbolInfo>
    where
        Context: std::borrow::Borrow<C2>,
        C2: Hash + Eq,
    {
        self.context_info_at_mut(candidates)
            .and_then(|table| table.stats_by_node_value_mut().get_mut(value))
    }

    /// Returns the table of the first candidate context present, if any.
    fn context_info_at<C2: ?Sized>(
        &self,
        candidates: &[&C2],
    ) -> Option<&ContextInformation<NodeValue, SymbolInfo>>
    where
        Context: std::borrow::Borrow<C2>,
        C2: Hash + Eq,
    {
        candidates
            .iter()
            .find_map(|candidate| self.by_context.get(*candidate))
    }

    /// Mutable counterpart of `context_info_at`.
    fn context_info_at_mut<C2: ?Sized>(
        &mut self,
        candidates: &[&C2],
    ) -> Option<&mut ContextInformation<NodeValue, SymbolInfo>>
    where
        Context: std::borrow::Borrow<C2>,
        C2: Hash + Eq,
    {
        // Find the context with a shared borrow first, then borrow mutably.
        // Returning a `&mut` out of the search closure would not pass the
        // borrow checker.
        let context = candidates
            .iter()
            .find(|candidate| self.by_context.contains_key(**candidate))?;
        self.by_context.get_mut(*context)
    }
}
impl<Context, NodeValue> InstancesToProbabilities for ContextPredict<Context, NodeValue, Instances>
where
    Context: Eq + Hash + Clone + std::fmt::Debug,
    NodeValue: Eq + Hash + Clone + Ord,
{
    type AsProbabilities = ContextPredict<Context, NodeValue, SymbolInfo>;
    fn instances_to_probabilities(
        &self,
        description: &str,
    ) -> ContextPredict<Context, NodeValue, SymbolInfo> {
        debug!(target: "entropy", "Converting ContextPredict {} to probabilities", description);
        let by_context = self
            .by_context
            .iter()
            .map(|(context, info)| {
                (
                    context.clone(),
                    info.instances_to_probabilities("ContextInformation"),
                )
            })
            .collect();
        ContextPredict { by_context }
    }
}
/// Collects statistics on node values, keyed by the tail of the path that
/// leads to the node.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct PathPredict<NodeValue, Statistics>
where
    NodeValue: Eq + Hash + Clone,
{
    /// Maximum number of trailing path items used as context. Paths longer
    /// than `depth` are truncated to their last `depth` items before being
    /// used as keys.
    depth: usize,

    /// Statistics, keyed by truncated path.
    context_predict: ContextPredict<IOPath, NodeValue, Statistics>,
}
impl<NodeValue> InstancesToProbabilities for PathPredict<NodeValue, Instances>
where
    NodeValue: Eq + Hash + Clone + Ord,
{
    type AsProbabilities = PathPredict<NodeValue, SymbolInfo>;

    /// Converts the per-path statistics and keeps the same `depth`.
    fn instances_to_probabilities(&self, description: &str) -> PathPredict<NodeValue, SymbolInfo> {
        let context_predict = self.context_predict.instances_to_probabilities(description);
        PathPredict {
            context_predict,
            depth: self.depth,
        }
    }
}
impl<NodeValue, Statistics> PathPredict<NodeValue, Statistics>
where
    NodeValue: Eq + Hash + Clone,
{
    pub fn new(depth: usize) -> Self {
        PathPredict {
            depth,
            context_predict: ContextPredict::new(),
        }
    }
    
    pub fn len(&self) -> usize {
        self.context_predict.len()
    }
    
    
    
    
    
    pub fn depth(&self) -> usize {
        self.depth
    }
    
    
    
    pub fn paths(&self) -> impl Iterator<Item = &IOPath> {
        self.context_predict.contexts()
    }
    
    pub fn into_iter(
        self,
    ) -> impl Iterator<Item = (IOPath, ContextInformation<NodeValue, Statistics>)> {
        self.context_predict.into_iter()
    }
    pub fn iter(
        &self,
    ) -> impl Iterator<Item = (&IOPath, &ContextInformation<NodeValue, Statistics>)> {
        self.context_predict.iter()
    }
    
    fn tail<'a>(&self, path: &'a [IOPathItem]) -> &'a [IOPathItem] {
        Self::tail_of(path, self.depth)
    }
    
    
    
    fn tail_of<'a>(path: &'a [IOPathItem], depth: usize) -> &'a [IOPathItem] {
        let path = if path.len() <= depth {
            path
        } else {
            &path[path.len() - depth..]
        };
        path
    }
    pub fn iter_at(&self, path: &[IOPathItem]) -> impl Iterator<Item = (&NodeValue, &Statistics)> {
        let tail = self.tail(path);
        self.context_predict.iter_at(tail)
    }
}
impl<NodeValue> PathPredict<NodeValue, Instances>
where
    NodeValue: Eq + Hash + Clone,
{
    /// Records one occurrence of `value` at `path`, truncated to its last
    /// `depth` items.
    pub fn add(&mut self, path: &[IOPathItem], value: NodeValue) {
        let mut key = IOPath::new();
        key.extend_from_slice(self.tail(path));
        self.context_predict.add(key, value);
    }

    /// Merges `other`, a dictionary of paths of length 0 or 1, into `self`
    /// without overriding anything `self` already knows.
    ///
    /// 1. For every non-empty path of `self`, the values that `other` holds
    ///    for `path.tail(1)` are added to that path if absent.
    /// 2. Every (path, value) entry of `other` is then added to `self` if
    ///    absent.
    ///
    /// In debug builds, this asserts that `other` only has paths of length
    /// 0 or 1 and that each of its entries has a count of 1.
    pub fn add_fallback(&mut self, other: &Self) {
        debug_assert!(
            other.paths().all(|path| path.len() <= 1),
            "The fallback dictionary should only contain paths of length 0 or 1."
        );
        debug!(target: "dictionary", "Adding fallback of length {}", other.len());

        // Step 1: enrich the existing non-empty paths of `self`.
        for (path, information) in self.context_predict.iter_mut() {
            if path.len() == 0 {
                continue;
            }
            let short_path = path.tail(1);
            for (value, statistics) in other.iter_at(short_path) {
                debug_assert_eq!(Into::<usize>::into(*statistics), 1);
                information.add_if_absent(value.clone());
            }
        }

        // Step 2: import the paths of `other` themselves.
        for (path, information) in other.iter() {
            for (value, statistics) in information.iter() {
                debug_assert_eq!(Into::<usize>::into(*statistics), 1);
                self.add_if_absent(path.borrow(), value.clone());
            }
        }
    }

    /// Records `value` at `path`, truncated to its last `depth` items, with
    /// a count of 1 unless it is already present.
    pub fn add_if_absent(&mut self, path: &[IOPathItem], value: NodeValue) {
        let mut key = IOPath::new();
        key.extend_from_slice(self.tail(path));
        self.context_predict.add_if_absent(key, value);
    }
}
impl<NodeValue> PathPredict<NodeValue, SymbolInfo>
where
    NodeValue: Eq + Hash + Clone,
{
    /// Returns the contexts under which `path` is looked up, most specific
    /// first.
    ///
    /// For paths of two or more items, the first candidate is the tail of
    /// length `depth`, which is the key `add` uses. The second is the tail of
    /// length 1, used as a fallback. Shorter paths are looked up as-is and
    /// repeated in both slots, so that every caller shares one fixed-size
    /// array type. A context missing on the first try is also missing on the
    /// second, so the repetition never changes a result.
    fn lookup_candidates<'a>(&self, path: &'a [IOPathItem]) -> [&'a [IOPathItem]; 2] {
        if path.len() >= 2 {
            [self.tail(path), Self::tail_of(path, 1)]
        } else {
            [path, path]
        }
    }

    /// Decodes `index` into the node value it stands for at `path`.
    ///
    /// Returns `None` if no candidate context is known, or if `index` is out
    /// of range for the first known one.
    pub fn value_by_symbol_index(
        &self,
        path: &[IOPathItem],
        index: SymbolIndex,
    ) -> Option<&NodeValue> {
        let candidates = self.lookup_candidates(path);
        self.context_predict
            .value_by_symbol_index(&candidates, index)
    }

    /// Returns the symbol information of `value` at `path`. Returns `None` if
    /// no candidate context is known, or if the first known one does not
    /// contain `value`.
    pub fn stats_by_node_value(&self, path: &[IOPathItem], value: &NodeValue) -> Option<&SymbolInfo>
    where
        NodeValue: std::fmt::Debug,
    {
        let candidates = self.lookup_candidates(path);
        self.context_predict.stats_by_node_value(&candidates, value)
    }

    /// Mutable counterpart of `stats_by_node_value`.
    pub fn stats_by_node_value_mut(
        &mut self,
        path: &[IOPathItem],
        value: &NodeValue,
    ) -> Option<&mut SymbolInfo>
    where
        NodeValue: std::fmt::Debug,
    {
        // `candidates` borrows `path`, not `self`, so the mutable borrow
        // below is allowed.
        let candidates = self.lookup_candidates(path);
        self.context_predict
            .stats_by_node_value_mut(&candidates, value)
    }

    /// Returns the distribution of the first known candidate context for
    /// `path`, if any.
    pub fn frequencies_at(
        &self,
        path: &[IOPathItem],
    ) -> Option<&Rc<RefCell<range_encoding::CumulativeDistributionFrequency>>> {
        let candidates = self.lookup_candidates(path);
        self.context_predict.frequencies_at(&candidates)
    }
}
/// Symbol emitted by `WindowPredict` for each value it sees.
///
/// NOTE: variant order is significant. The derived `Ord` sorts every
/// `BackReference` before every `DictionaryIndex`, and
/// `ContextInformation::instances_to_probabilities` numbers symbols in sorted
/// order. Reordering the variants therefore changes the symbol numbering
/// (and presumably the serialized form too).
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, PartialOrd, Ord)]
enum WindowPrediction {
    /// The value was already in the window of recently used values, at this
    /// position.
    BackReference(BackReference),
    /// The value was not in the window. It is identified by its position in
    /// the dictionary of all values seen.
    DictionaryIndex(DictionaryIndex),
}
/// Predicts values by looking them up in a small window of recently used
/// values, falling back to a dictionary of every value seen.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WindowPredict<NodeValue, Statistics>
where
    NodeValue: Clone + Eq + Hash,
{
    /// Maximum number of values kept in `latest_values`.
    width: usize,
    /// The window: recently used values, most recent first. It is not carried
    /// over by `instances_to_probabilities`, which starts with an empty
    /// window.
    latest_values: Vec<NodeValue>,
    /// Every distinct value seen, in order of first appearance. Indexed by
    /// `DictionaryIndex`.
    value_by_dictionary_index: Vec<NodeValue>,
    /// Inverse mapping of `value_by_dictionary_index`.
    dictionary_index_by_value: HashMap<NodeValue, DictionaryIndex>,
    /// Statistics on the `WindowPrediction` symbols emitted for the values.
    info: ContextInformation<WindowPrediction, Statistics>,
}
impl<NodeValue, Statistics> WindowPredict<NodeValue, Statistics>
where
    NodeValue: Clone + Eq + Hash,
{
    
    pub fn new(width: usize) -> Self {
        WindowPredict {
            width,
            value_by_dictionary_index: Vec::with_capacity(1024), 
            dictionary_index_by_value: HashMap::with_capacity(1024),
            latest_values: Vec::with_capacity(width),
            info: ContextInformation::new(),
        }
    }
    
    
    
    
    fn window_move_to_front(&mut self, index: BackReference) -> Result<(), ()> {
        let as_usize = Into::<usize>::into(index);
        if as_usize == 0 {
            
            
            return Ok(());
        }
        if as_usize >= self.latest_values.len() {
            return Err(());
        }
        let ref mut slice = self.latest_values.as_mut_slice()[0..as_usize];
        slice.rotate_right(1);
        Ok(())
    }
    
    
    
    
    
    fn window_insert_value(&mut self, value: &NodeValue) -> Option<BackReference> {
        
        if let Some(index) = self.latest_values.iter().position(|v| v == value) {
            let index = BackReference(index);
            self.window_move_to_front(index).unwrap(); 
            return Some(index);
        }
        if self.latest_values.len() < self.width {
            
            self.latest_values.insert(0, value.clone());
        } else {
            
            let slice = self.latest_values.as_mut_slice();
            slice[slice.len() - 1] = value.clone();
            slice.rotate_right(1);
        }
        None
    }
}
impl<NodeValue> WindowPredict<NodeValue, Instances>
where
    NodeValue: Clone + Eq + std::hash::Hash + std::fmt::Debug,
{
    /// Records one occurrence of `value`.
    ///
    /// The value is added to the dictionary if it is new, then pushed through
    /// the window. The resulting symbol is counted in `info`: a
    /// back-reference if the value was already in the window, its dictionary
    /// index otherwise.
    pub fn add(&mut self, value: NodeValue) {
        debug!(target: "predict", "WindowPredict: Inserting value {:?}", value);

        let next_index = DictionaryIndex(self.value_by_dictionary_index.len());
        let dictionary_index = *self
            .dictionary_index_by_value
            .entry(value.clone())
            .or_insert(next_index);
        if dictionary_index != next_index {
            // Known value: both directions of the dictionary must agree.
            debug_assert_eq!(value, self.value_by_dictionary_index[dictionary_index.0]);
        } else {
            // First occurrence: it has just been given the next free index.
            self.value_by_dictionary_index.push(value.clone());
        }

        let symbol = self
            .window_insert_value(&value)
            .map(WindowPrediction::BackReference)
            .unwrap_or(WindowPrediction::DictionaryIndex(dictionary_index));
        self.info.add(symbol);
    }
}
impl<NodeValue> WindowPredict<NodeValue, SymbolInfo>
where
    NodeValue: Clone + Eq + std::hash::Hash + std::fmt::Debug,
{
    
    pub fn frequencies(
        &self,
    ) -> Option<&Rc<RefCell<range_encoding::CumulativeDistributionFrequency>>> {
        self.info
            .stats_by_node_value()
            .values()
            .next()
            .map(|any| &any.distribution)
    }
    
    
    
    pub fn value_by_symbol_index(&mut self, index: SymbolIndex) -> Option<NodeValue> {
        match self.info.value_by_symbol_index(index) {
            None => None,
            Some(&WindowPrediction::DictionaryIndex(dictionary_index)) => {
                
                let result = self
                    .value_by_dictionary_index
                    .get(dictionary_index.0)?
                    .clone();
                self.window_insert_value(&result);
                Some(result)
            }
            Some(&WindowPrediction::BackReference(index)) => {
                let as_usize: usize = index.into();
                let result = self.latest_values.get(as_usize)?.clone();
                if let Err(_) = self.window_move_to_front(index) {
                    return None;
                }
                Some(result)
            }
        }
    }
    
    pub fn stats_by_node_value_mut(&mut self, value: &NodeValue) -> Option<&mut SymbolInfo> {
        
        
        debug!(target: "predict", "WindowPredict: Fetching {:?}", value);
        let prediction = match self.window_insert_value(value) {
            Some(backref) => WindowPrediction::BackReference(backref),
            None => {
                debug!(target: "predict", "WindowPredict: Value {:?} is not in the window, let's look for it in the dictionary", value);
                let index = self.dictionary_index_by_value.get(value)?.clone();
                WindowPrediction::DictionaryIndex(index)
            }
        };
        debug!(target: "predict", "WindowPredict: {:?} has just been inserted and will be encoded as {:?}", value, prediction);
        self.info.stats_by_node_value_mut().get_mut(&prediction)
    }
}
impl<NodeValue> InstancesToProbabilities for WindowPredict<NodeValue, Instances>
where
    NodeValue: Clone + Eq + Hash + Ord,
{
    type AsProbabilities = WindowPredict<NodeValue, SymbolInfo>;

    /// Converts the symbol statistics and copies the dictionary. The window
    /// itself starts out empty.
    fn instances_to_probabilities(
        &self,
        _description: &str,
    ) -> WindowPredict<NodeValue, SymbolInfo> {
        let info = self.info.instances_to_probabilities("WindowPredict::info");
        WindowPredict {
            width: self.width,
            latest_values: Vec::with_capacity(self.width),
            value_by_dictionary_index: self.value_by_dictionary_index.clone(),
            dictionary_index_by_value: self.dictionary_index_by_value.clone(),
            info,
        }
    }
}