diff --git a/src/context_state.rs b/src/context_state.rs index 062d74d..92d5fb2 100644 --- a/src/context_state.rs +++ b/src/context_state.rs @@ -1,4 +1,4 @@ -use crate::rans::{PROB_BITS, ONE_PROB}; +use crate::rans::{ONE_PROB, PROB_BITS}; const INIT_PROB: u16 = 1 << (PROB_BITS - 1); const UPDATE_RATE: u32 = 4; @@ -33,7 +33,7 @@ impl<'a> Context<'a> { pub fn update(&mut self, bit: bool) { let old = self.state.contexts[self.index]; - self.state.contexts[self.index] = if bit { + self.state.contexts[self.index] = if !bit { old + ((ONE_PROB - old as u32 + UPDATE_ADD) >> UPDATE_RATE) as u8 } else { old - ((old as u32 + UPDATE_ADD) >> UPDATE_RATE) as u8 diff --git a/src/rans.rs b/src/rans.rs index e81308a..3250e8a 100644 --- a/src/rans.rs +++ b/src/rans.rs @@ -38,15 +38,15 @@ impl RansCoder { let mut state = 1 << l_bits; let mut byte = 0u8; - let mut bit = 8; + let mut bit = 0; let mut flush_state: Box = if self.use_bitstream { Box::new(|state: &mut u32| { - bit -= 1; byte |= ((*state & 1) as u8) << bit; - if bit == 0 { + bit += 1; + if bit == 8 { buffer.push(byte); byte = 0; - bit = 8; + bit = 0; } *state >>= 1; }) @@ -61,7 +61,7 @@ impl RansCoder { let max_state_factor: u32 = 1 << (l_bits + num_flush_bits - PROB_BITS); for step in self.bits.into_iter().rev() { let prob = step as u32 & 32767; - let (start, prob) = if step & 32768 != 0 { + let (start, prob) = if step & 32768 == 0 { (0, prob) } else { (prob, ONE_PROB - prob) @@ -118,7 +118,7 @@ impl CostCounter { impl EntropyCoder for CostCounter { fn encode_bit(&mut self, bit: bool, prob: u16) { - let prob = if bit { + let prob = if !bit { prob as u32 } else { ONE_PROB - prob as u32 @@ -163,8 +163,8 @@ impl<'a> RansDecoder<'a> { self.data = &self.data[1..]; self.bits_left = 8; } - self.state = (self.state << 1) | (self.byte & 1) as u32; - self.byte >>= 1; + self.state = (self.state << 1) | (self.byte >> 7) as u32; + self.byte <<= 1; self.bits_left -= 1; } } else { @@ -174,12 +174,12 @@ impl<'a> RansDecoder<'a> { } } - let bit = (self.state & PROB_MASK) < prob; + let bit = (self.state & PROB_MASK) >= prob; let (start, prob) = if bit { - (0, prob) - } else { (prob, ONE_PROB - prob) + } else { + (0, prob) }; self.state = prob * (self.state >> PROB_BITS) + (self.state & PROB_MASK) - start;