diff --git a/src/greedy_packer.rs b/src/greedy_packer.rs index 684b719..2b7450f 100644 --- a/src/greedy_packer.rs +++ b/src/greedy_packer.rs @@ -1,9 +1,11 @@ -use crate::lz::LzCoder; +use crate::lz; use crate::match_finder::MatchFinder; +use crate::rans::RansCoder; pub fn pack(data: &[u8]) -> Vec { let match_finder = MatchFinder::new(data); - let mut lz = LzCoder::new(); + let mut rans_coder = RansCoder::new(); + let mut state = lz::CoderState::new(); let mut pos = 0; while pos < data.len() { @@ -12,14 +14,18 @@ pub fn pack(data: &[u8]) -> Vec { let max_offset = 1 << (m.length * 3 - 1).min(31); let offset = pos - m.pos; if offset < max_offset { - lz.encode_match(offset, m.length); + lz::Op::Match { + offset: offset as u32, + len: m.length as u32, + } + .encode(&mut rans_coder, &mut state); pos += m.length; encoded_match = true; } } if !encoded_match { - let offset = lz.last_offset(); + let offset = state.last_offset() as usize; if offset != 0 { let length = data[pos..] .iter() @@ -27,7 +33,11 @@ pub fn pack(data: &[u8]) -> Vec { .take_while(|(a, b)| a == b) .count(); if length > 0 { - lz.encode_match(offset, length); + lz::Op::Match { + offset: offset as u32, + len: length as u32, + } + .encode(&mut rans_coder, &mut state); pos += length; encoded_match = true; } @@ -35,10 +45,11 @@ pub fn pack(data: &[u8]) -> Vec { } if !encoded_match { - lz.encode_literal(data[pos]); + lz::Op::Literal(data[pos]).encode(&mut rans_coder, &mut state); pos += 1; } } - lz.finish() + lz::encode_eof(&mut rans_coder, &mut state); + rans_coder.finish() } diff --git a/src/lib.rs b/src/lib.rs index aa735ce..70b1b2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,8 @@ mod greedy_packer; mod lz; mod match_finder; mod rans; +mod parsing_packer; -pub use greedy_packer::pack; +pub use greedy_packer::pack as pack_fast; +pub use parsing_packer::pack; pub use lz::unpack; \ No newline at end of file diff --git a/src/lz.rs b/src/lz.rs index 6aa52b8..6eb128c 100644 --- a/src/lz.rs +++ b/src/lz.rs @@ -1,70 +1,86 @@ use crate::context_state::ContextState; -use crate::rans::{RansCoder, RansDecoder}; +use crate::rans::{EntropyCoder, RansDecoder}; -pub struct LzCoder { - contexts: ContextState, - range_coder: RansCoder, - last_offset: usize, +#[derive(Copy, Clone, Debug)] +pub enum Op { + Literal(u8), + Match { offset: u32, len: u32 }, } -impl LzCoder { - pub fn new() -> LzCoder { - LzCoder { +impl Op { + pub fn encode(&self, coder: &mut dyn EntropyCoder, state: &mut CoderState) { + match self { + &Op::Literal(lit) => { + encode_bit(coder, state, 0, false); + let mut context_index = 1; + for i in (0..8).rev() { + let bit = (lit >> i) & 1 != 0; + encode_bit(coder, state, context_index, bit); + context_index = (context_index << 1) | bit as usize; + } + } + &Op::Match { offset, len } => { + encode_bit(coder, state, 0, true); + encode_bit(coder, state, 256, offset != state.last_offset); + if offset != state.last_offset { + encode_length(coder, state, 257, offset + 1); + state.last_offset = offset; + } + encode_length(coder, state, 257 + 64, len); + } + } + } +} + +pub fn encode_eof(coder: &mut dyn EntropyCoder, state: &mut CoderState) { + encode_bit(coder, state, 0, true); + encode_bit(coder, state, 256, true); + encode_length(coder, state, 257, 1); +} + +fn encode_bit( + coder: &mut dyn EntropyCoder, + state: &mut CoderState, + context_index: usize, + bit: bool, +) { + coder.encode_with_context(bit, &mut state.contexts.context_mut(context_index)); +} + +fn encode_length( + coder: &mut dyn EntropyCoder, + state: &mut CoderState, + context_start: usize, + value: u32, +) { + assert!(value >= 1); + let top_bit = u32::BITS - 1 - value.leading_zeros(); + let mut context_index = context_start; + for i in 0..top_bit { + encode_bit(coder, state, context_index, true); + encode_bit(coder, state, context_index + 1, (value >> i) & 1 != 0); + context_index += 2; + } + encode_bit(coder, state, context_index, false); +} + +#[derive(Clone)] +pub struct CoderState { + contexts: ContextState, + last_offset: u32, +} + +impl CoderState { + pub fn new() -> CoderState { + CoderState { contexts: ContextState::new(1 + 255 + 1 + 64 + 64), - range_coder: RansCoder::new(), last_offset: 0, } } - pub fn encode_literal(&mut self, byte: u8) { - self.bit(false, 0); - let mut context_index = 1; - for i in (0..8).rev() { - let bit = (byte >> i) & 1 != 0; - self.bit(bit, context_index); - context_index = (context_index << 1) | bit as usize; - } - } - - pub fn encode_match(&mut self, offset: usize, length: usize) { - self.bit(true, 0); - if offset != self.last_offset { - self.last_offset = offset; - self.bit(true, 256); - self.length(offset + 1, 257); - } else { - self.bit(false, 256); - } - self.length(length, 257 + 64); - } - - pub fn finish(mut self) -> Vec { - self.bit(true, 0); - self.bit(true, 256); - self.length(1, 257); - self.range_coder.finish() - } - - pub fn last_offset(&self) -> usize { + pub fn last_offset(&self) -> u32 { self.last_offset } - - fn length(&mut self, value: usize, context_start: usize) { - assert!(value >= 1); - let top_bit = usize::BITS - 1 - value.leading_zeros(); - let mut context_index = context_start; - for i in 0..top_bit { - self.bit(true, context_index); - self.bit((value >> i) & 1 != 0, context_index + 1); - context_index += 2; - } - self.bit(false, context_index); - } - - fn bit(&mut self, b: bool, context_index: usize) { - self.range_coder - .encode_with_context(b, &mut self.contexts.context_mut(context_index)); - } } pub fn unpack(packed_data: &[u8]) -> Vec { diff --git a/src/main.rs b/src/main.rs index f261b67..0e11b89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ -use std::{fs::File, path::PathBuf}; -use std::io::prelude::*; use anyhow::{bail, Result}; +use std::io::prelude::*; +use std::{fs::File, path::PathBuf}; fn main() -> Result<()> { let mut args = pico_args::Arguments::from_env(); @@ -8,12 +8,18 @@ fn main() -> Result<()> { match args.subcommand()?.as_ref().map(|s| s.as_str()) { None => print_help(), Some("pack") => { + let fast = args.contains("--fast"); + let infile = args.free_from_os_str::(|s| Ok(s.into()))?; let outfile = args.free_from_os_str::(|s| Ok(s.into()))?; let mut data = vec![]; File::open(infile)?.read_to_end(&mut data)?; - let packed_data = upkr::pack(&data); + let packed_data = if fast { + upkr::pack_fast(&data) + } else { + upkr::pack(&data) + }; File::create(outfile)?.write_all(&packed_data)?; } Some("unpack") => { diff --git a/src/match_finder.rs b/src/match_finder.rs index ae70fd7..9c68051 100644 --- a/src/match_finder.rs +++ b/src/match_finder.rs @@ -58,7 +58,6 @@ impl MatchFinder { right_index: index, right_length: usize::MAX, current_length: 0, - patience_left: 0, matches_left: self.max_matches, max_length: 0, queue: BinaryHeap::new(), @@ -79,7 +78,6 @@ pub struct Matches<'a> { right_index: usize, right_length: usize, current_length: usize, - patience_left: usize, matches_left: usize, max_length: usize, queue: BinaryHeap, @@ -103,9 +101,7 @@ impl<'a> Iterator for Matches<'a> { { return None; } - self.patience_left = self.finder.patience; while self.matches_left > 0 - && self.patience_left > 0 && (self.left_length == self.current_length || self.right_length == self.current_length) { @@ -129,30 +125,43 @@ impl<'a> Iterator for Matches<'a> { impl<'a> Matches<'a> { fn move_left(&mut self) { - if self.left_index > 0 { + let mut patience = self.finder.patience; + while self.left_length > 0 && patience > 0 && self.left_index > 0 { self.left_index -= 1; self.left_length = self .left_length .min(self.finder.lcp[self.left_index] as usize); - } else { - self.left_length = 0; + if self + .pos_range + .contains(&(self.finder.suffixes[self.left_index] as usize)) + { + return; + } + patience -= 1; } + self.left_length = 0; } fn move_right(&mut self) { - self.right_index += 1; - self.right_length = self - .right_length - .min(self.finder.lcp[self.right_index - 1] as usize); + let mut patience = self.finder.patience; + while self.right_length > 0 && patience > 0 && self.right_index + 1 < self.finder.suffixes.len() { + self.right_index += 1; + self.right_length = self + .right_length + .min(self.finder.lcp[self.right_index - 1] as usize); + if self + .pos_range + .contains(&(self.finder.suffixes[self.right_index] as usize)) + { + return; + } + patience -= 1; + } + self.right_length = 0; } fn add_to_queue(&mut self, pos: i32) { - if self.pos_range.contains(&(pos as usize)) { - self.queue.push(pos as usize); - self.matches_left -= 1; - self.patience_left = self.finder.patience; - } else { - self.patience_left = 0; - } + self.queue.push(pos as usize); + self.matches_left -= 1; } } diff --git a/src/parsing_packer.rs b/src/parsing_packer.rs new file mode 100644 index 0000000..e3382d0 --- /dev/null +++ b/src/parsing_packer.rs @@ -0,0 +1,122 @@ +use std::collections::HashMap; +use std::rc::Rc; + +use crate::match_finder::MatchFinder; +use crate::rans::{RansCoder, CostCounter}; +use crate::lz; + +pub fn pack(data: &[u8]) -> Vec { + let mut parse = parse(data); + let mut ops = vec![]; + while let Some(link) = parse { + ops.push(link.op); + parse = link.prev.clone(); + } + let mut state = lz::CoderState::new(); + let mut coder = RansCoder::new(); + for op in ops.into_iter().rev() { + op.encode(&mut coder, &mut state); + } + lz::encode_eof(&mut coder, &mut state); + coder.finish() +} + +struct Parse { + prev: Option>, + op: lz::Op, +} + +struct Arrival { + parse: Option>, + state: lz::CoderState, + cost: f64, +} + +type Arrivals = HashMap>; + +const MAX_ARRIVALS: usize = 16; + +fn parse(data: &[u8]) -> Option> { + let match_finder = MatchFinder::new(data); + + let mut arrivals: Arrivals = HashMap::new(); + fn add_arrival(arrivals: &mut Arrivals, pos: usize, arrival: Arrival) { + let vec = arrivals.entry(pos).or_default(); + if vec.len() < MAX_ARRIVALS || vec[MAX_ARRIVALS - 1].cost > arrival.cost { + vec.push(arrival); + vec.sort_by(|a, b| { + a.cost + .partial_cmp(&b.cost) + .unwrap_or(std::cmp::Ordering::Equal) + }); + if vec.len() > MAX_ARRIVALS { + vec.pop(); + } + } + } + fn add_match(arrivals: &mut Arrivals, pos: usize, offset: usize, length: usize, arrival: &Arrival) { + let mut cost_counter = CostCounter(0.); + let mut state = arrival.state.clone(); + let op = lz::Op::Match { offset: offset as u32, len: length as u32 }; + op.encode(&mut cost_counter, &mut state); + add_arrival( + arrivals, + pos + length, + Arrival { + parse: Some(Rc::new(Parse { + prev: arrival.parse.clone(), + op, + })), + state, + cost: arrival.cost + cost_counter.0, + }, + ); + } + add_arrival( + &mut arrivals, + 0, + Arrival { + parse: None, + state: lz::CoderState::new(), + cost: 0.0, + }, + ); + for pos in 0..data.len() { + for arrival in arrivals.remove(&pos).unwrap() { + let mut found_last_offset = false; + for m in match_finder.matches(pos) { + let offset = pos - m.pos; + if offset as u32 == arrival.state.last_offset() { + found_last_offset = true; + } + add_match(&mut arrivals, pos, offset, m.length, &arrival); + } + + if !found_last_offset && arrival.state.last_offset() > 0 { + let offset = arrival.state.last_offset() as usize; + let length = data[pos..].iter().zip(data[(pos - offset)..].iter()).take_while(|(a, b)| a == b).count(); + if length > 0 { + add_match(&mut arrivals, pos, offset, length, &arrival); + } + } + + let mut cost_counter = CostCounter(0.); + let mut state = arrival.state; + let op = lz::Op::Literal(data[pos]); + op.encode(&mut cost_counter, &mut state); + add_arrival( + &mut arrivals, + pos + 1, + Arrival { + parse: Some(Rc::new(Parse { + prev: arrival.parse, + op, + })), + state, + cost: arrival.cost + cost_counter.0, + }, + ); + } + } + arrivals.remove(&data.len()).unwrap()[0].parse.clone() +} diff --git a/src/rans.rs b/src/rans.rs index 1c798b6..1f4445a 100644 --- a/src/rans.rs +++ b/src/rans.rs @@ -4,23 +4,29 @@ const L_BITS: u32 = 16; pub const PROB_BITS: u32 = 12; pub const ONE_PROB: u32 = 1 << PROB_BITS; +pub trait EntropyCoder { + fn encode_bit(&mut self, bit: bool, prob: u16); + + fn encode_with_context(&mut self, bit: bool, context: &mut Context) { + self.encode_bit(bit, context.prob()); + context.update(bit); + } +} + pub struct RansCoder(Vec); +impl EntropyCoder for RansCoder { + fn encode_bit(&mut self, bit: bool, prob: u16) { + assert!(prob < 32768); + self.0.push(prob | ((bit as u16) << 15)); + } +} + impl RansCoder { pub fn new() -> RansCoder { RansCoder(Vec::new()) } - pub fn encode_with_context(&mut self, bit: bool, context: &mut Context) { - self.encode_bit(bit, context.prob()); - context.update(bit); - } - - pub fn encode_bit(&mut self, bit: bool, prob: u16) { - assert!(prob < 32768); - self.0.push(prob | ((bit as u16) << 15)); - } - pub fn finish(self) -> Vec { let mut buffer = vec![]; let mut state = 1 << L_BITS; @@ -51,6 +57,16 @@ impl RansCoder { } } +pub struct CostCounter(pub f64); + +impl EntropyCoder for CostCounter { + fn encode_bit(&mut self, bit: bool, prob: u16) { + let prob = if bit { prob as u32 } else { ONE_PROB - prob as u32 }; + let inv_prob = ONE_PROB as f64 / prob as f64; + self.0 += inv_prob.log2(); + } +} + pub struct RansDecoder<'a> { data: &'a [u8], state: u32,