Skip to content

Adaptive Probability Maps #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 27, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions entropy/ari/apm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*!

Adaptive Probability Maps (APM)

# Links
* http://mattmahoney.net/dc/bbb.cpp
* https://github.com/IlyaGrebnov/libbsc

# Example

# Credit
Matt Mahoney for the wonderful 'bbb' commented source

*/

use super::Border;
/// Linear ("flat") probability of a zero bit, stored as a fixed-point
/// fraction of `FLAT_TOTAL` (see `Bit::get_range`: zero owns `[0, fp)`)
pub type FlatProbability = u16;
/// Probability in the logistic ("wide") domain: `ln(p/(1-p))` scaled by
/// `WIDE_OFFSET` (see `Bit::to_wide` / `Bit::from_wide`)
pub type WideProbability = i16;

// Number of low bits of a wide probability used as the interpolation weight
static BIN_WEIGHT_BITS: uint = 8;
// Denominator of the interpolation weight (256)
static BIN_WEIGHT_TOTAL: uint = 1<<BIN_WEIGHT_BITS;
// Fixed-point precision of a flat probability, in bits
static FLAT_BITS: FlatProbability = 12;
// Denominator of a flat probability (4096)
static FLAT_TOTAL: int = 1<<FLAT_BITS;
// Width, in bits, of the wide-probability window covered by a `Gate`
static WIDE_BITS: uint = 12;
// Half of that window (2048); also the scale applied to the logit
static WIDE_OFFSET: WideProbability = 1<<(WIDE_BITS-1);
//static WIDE_TOTAL: int = (1<<WIDE_BITS)+1;
// Index of the gate bin that corresponds to a wide probability of zero (8)
static PORTAL_OFFSET: uint = 1<<(WIDE_BITS-BIN_WEIGHT_BITS-1);
// Total number of gate bins: `PORTAL_OFFSET` on each side plus the centre (17)
static PORTAL_BINS: uint = 2*PORTAL_OFFSET + 1;


/// Bit probability model.
/// Wraps the flat probability of the next bit being zero,
/// expressed as a fraction of `FLAT_TOTAL`.
pub struct Bit(FlatProbability);

impl Bit {
    /// Return a model where zero and one are equally likely
    /// (flat probability of exactly `FLAT_TOTAL / 2`)
    #[inline]
    pub fn new_equal() -> Bit {
        Bit(FLAT_TOTAL as FlatProbability >> 1)
    }

    /// Return the flat (linear) probability of a zero bit,
    /// as a fraction of `FLAT_TOTAL`
    #[inline]
    pub fn to_flat(&self) -> FlatProbability {
        let Bit(fp) = *self;
        fp
    }

    /// Return the wide probability: `ln(p/(1-p))` scaled by `WIDE_OFFSET`.
    ///
    /// The flat probability is first saturated into `[1, FLAT_TOTAL-1]`,
    /// because a value of 0 or `FLAT_TOTAL` (or above) makes the logit
    /// infinite or NaN, which cannot be converted to an `i16`.
    /// A precomputed "stretch" table could replace the float math here.
    #[inline]
    pub fn to_wide(&self) -> WideProbability {
        let raw = self.to_flat() as int;
        let fp = if raw < 1 {
            1
        } else if raw >= FLAT_TOTAL {
            FLAT_TOTAL - 1
        } else {
            raw
        };
        let p = (fp as f32) / (FLAT_TOTAL as f32);
        let d = (p / (1.0-p)).ln();
        // after the saturation above the scaled logit is finite and
        // well inside the i16 range, so the conversion cannot fail
        (d * WIDE_OFFSET as f32).to_i16().unwrap()
    }

    /// Construct from a flat probability (stored as given, unchecked)
    #[inline]
    pub fn from_flat(fp: FlatProbability) -> Bit {
        Bit(fp)
    }

    /// Construct from a wide probability by applying the logistic
    /// function `1/(1+exp(-wp/WIDE_OFFSET))` and scaling to `FLAT_TOTAL`.
    /// A precomputed "squash" table could replace the float math here.
    #[inline]
    pub fn from_wide(wp: WideProbability) -> Bit {
        let d = (wp as f32) / (WIDE_OFFSET as f32);
        let p = 1.0 / (1.0 + (-d).exp());
        let fp = (p * FLAT_TOTAL as f32).to_u16().unwrap();
        Bit(fp)
    }

    /// Adapt after observing a zero: move the probability of zero
    /// towards `FLAT_TOTAL - bias` by `1/2^rate` of the remaining distance
    pub fn update_zero(&mut self, rate: int, bias: int) {
        let &Bit(ref mut fp) = self;
        let one = FLAT_TOTAL - bias - (*fp as int);
        // NOTE(review): if fp is already above FLAT_TOTAL - bias the step is
        // negative and the cast relies on wrapping u16 arithmetic - confirm
        *fp += (one >> rate) as FlatProbability;
    }

    /// Adapt after observing a one: move the probability of zero
    /// towards `bias` by `1/2^rate` of the remaining distance
    pub fn update_one(&mut self, rate: int, bias: int) {
        let &Bit(ref mut fp) = self;
        let zero = (*fp as int) - bias;
        *fp -= (zero >> rate) as FlatProbability;
    }

    /// Adapt after observing `value` (`false` is zero, `true` is one);
    /// larger `rate` values produce smaller steps
    #[inline]
    pub fn update(&mut self, value: bool, rate: int, bias: int) {
        if !value {
            self.update_zero(rate, bias)
        }else {
            self.update_one(rate, bias)
        }
    }
}

impl super::Model<bool> for Bit {
    /// Range owned by `value` within `[0, FLAT_TOTAL)`:
    /// zero takes the lower part, one takes the upper part
    fn get_range(&self, value: bool) -> (Border,Border) {
        let split = self.to_flat() as Border;
        if value {
            (split, FLAT_TOTAL as Border)
        }else {
            (0, split)
        }
    }

    /// Find which bit owns `offset` and return it with its range
    fn find_value(&self, offset: Border) -> (bool,Border,Border) {
        let total = FLAT_TOTAL as Border;
        assert!(offset < total,
            "Invalid bit offset {} requested", offset);
        let split = self.to_flat() as Border;
        // offsets at or above the split point belong to the one bit
        let value = offset >= split;
        let (low, high) = if value { (split, total) } else { (0, split) };
        (value, low, high)
    }

    /// Total of both ranges, constant for every `Bit`
    fn get_denominator(&self) -> Border {
        FLAT_TOTAL as Border
    }
}


/// Binary context gate.
/// Maps an input binary probability into a new one
/// by interpolating between two neighbouring entries of an internal map
/// that is indexed in the non-linear (wide) probability space.
pub struct Gate {
    /// One adaptive `Bit` per bin; bin `i` sits at the wide probability
    /// `(i - PORTAL_OFFSET) << BIN_WEIGHT_BITS`
    map: [Bit, ..PORTAL_BINS],
}

/// Position of a wide probability inside a `Gate`, as returned by
/// `Gate::pass`/`Gate::pass_wide` and consumed by `Gate::update*`:
/// the lower bin index and the interpolation weight of the upper bin
/// (out of `BIN_WEIGHT_TOTAL`)
pub type BinCoords = (uint, uint); // (index, weight)

impl Gate {
    /// Create a new gate instance.
    /// Bin `i` is seeded with the flat equivalent of its own wide
    /// probability, which runs evenly from `-WIDE_OFFSET` (bin 0)
    /// to `+WIDE_OFFSET` (last bin).
    pub fn new() -> Gate {
        let mut g = Gate {
            map: [Bit::new_equal(), ..PORTAL_BINS],
        };
        for (i,bit) in g.map.mut_iter().enumerate() {
            // relative position of the bin in [-1, 1]
            let rp = (i as f32)/(PORTAL_OFFSET as f32) - 1.0;
            let wp = (rp * (WIDE_OFFSET as f32)).to_i16().unwrap();
            *bit = Bit::from_wide(wp);
        }
        g
    }

    /// Pass a bit through the gate.
    /// Returns the refined bit and the bin coordinates that have to be
    /// handed back to `update*` once the actual value is known.
    #[inline]
    pub fn pass(&self, bit: &Bit) -> (Bit, BinCoords) {
        let (fp, index) = self.pass_wide(bit.to_wide());
        (Bit::from_flat(fp), index)
    }

    /// Pass a wide probability on input, usable when
    /// you mix it linearly beforehand (libbsc does that).
    ///
    /// The map only covers wide probabilities in
    /// `[-WIDE_OFFSET, WIDE_OFFSET)`; anything outside is saturated to the
    /// nearest edge so that `index+1` always stays inside the map.
    pub fn pass_wide(&self, wp: WideProbability) -> (FlatProbability, BinCoords) {
        let lowest = -WIDE_OFFSET;
        let highest = WIDE_OFFSET - 1;
        let wp = if wp < lowest {
            lowest
        } else if wp > highest {
            highest
        } else {
            wp
        };
        // widen before shifting the origin so the sum cannot overflow i16;
        // the result lies in [0, 2*WIDE_OFFSET)
        let shifted = ((wp as int) + (WIDE_OFFSET as int)) as uint;
        let index = shifted >> BIN_WEIGHT_BITS;
        let weight = shifted & (BIN_WEIGHT_TOTAL-1);
        let z = [
            self.map[index+0].to_flat() as uint,
            self.map[index+1].to_flat() as uint];
        // linear interpolation between the two neighbouring bins
        let sum = z[0]*(BIN_WEIGHT_TOTAL-weight) + z[1]*weight;
        let fp = (sum >> BIN_WEIGHT_BITS) as FlatProbability;
        (fp, (index, weight))
    }

    //TODO: weight update ratio & bias as well

    /// Adapt both bins around `bc` after observing a zero.
    /// The interpolation weight is currently ignored.
    pub fn update_zero(&mut self, bc: BinCoords, rate: int, bias: int) {
        let (index, _) = bc;
        self.map[index+0].update_zero(rate, bias);
        self.map[index+1].update_zero(rate, bias);
    }

    /// Adapt both bins around `bc` after observing a one.
    /// The interpolation weight is currently ignored.
    pub fn update_one(&mut self, bc: BinCoords, rate: int, bias: int) {
        let (index, _) = bc;
        self.map[index+0].update_one(rate, bias);
        self.map[index+1].update_one(rate, bias);
    }

    /// Adapt both bins around `bc` after observing `value`
    /// (`false` is zero, `true` is one)
    #[inline]
    pub fn update(&mut self, value: bool, bc: BinCoords, rate: int, bias: int) {
        if !value {
            self.update_zero(bc, rate, bias)
        }else {
            self.update_one(bc, rate, bias)
        }
    }
}
27 changes: 15 additions & 12 deletions entropy/ari/bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,27 @@ pub struct Model {
zero: Border,
/// total frequency (constant)
total: Border,
/// learning rate
pub rate: Border,
}

impl Model {
/// Create a new flat (50/50 probability) instance
pub fn new_flat(threshold: Border) -> Model {
assert!(threshold >= 2);
pub fn new_flat(threshold: Border, rate: Border) -> Model {
Model {
zero: threshold>>1,
total: threshold,
rate: rate,
}
}

/// Create a new instance with a given percentage for zeroes
pub fn new_custom(zero_percent: u8, threshold: Border) -> Model {
pub fn new_custom(zero_percent: u8, threshold: Border, rate: Border) -> Model {
assert!(threshold >= 100);
Model {
zero: (zero_percent as Border)*threshold/100,
total: threshold,
rate: rate,
}
}

Expand All @@ -56,24 +59,24 @@ impl Model {
}

/// Update the frequency of zero
pub fn update_zero(&mut self, factor: uint) {
debug!("\tUpdating zero by a factor of {}", factor);
self.zero += (self.total-self.zero) >> factor;
pub fn update_zero(&mut self) {
debug!("\tUpdating zero");
self.zero += (self.total-self.zero) >> self.rate;
}

/// Update the frequency of one
pub fn update_one(&mut self, factor: uint) {
debug!("\tUpdating one by a factor of {}", factor);
self.zero -= self.zero >> factor;
pub fn update_one(&mut self) {
debug!("\tUpdating one");
self.zero -= self.zero >> self.rate;
}

/// Update frequencies in favor of given 'value'
/// Lower factors produce more aggressive updates
pub fn update(&mut self, value: bool, factor: uint) {
pub fn update(&mut self, value: bool) {
if value {
self.update_one(factor)
self.update_one()
}else {
self.update_zero(factor)
self.update_zero()
}
}
}
Expand Down
1 change: 1 addition & 0 deletions entropy/ari/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use std::io::IoResult;

pub use self::table::{ByteDecoder, ByteEncoder};

pub mod apm;
pub mod bin;
pub mod table;
#[cfg(test)]
Expand Down
66 changes: 52 additions & 14 deletions entropy/ari/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@ fn roundtrip(bytes: &[u8]) {
assert_eq!(bytes.as_slice(), decoded.as_slice());
}

fn encode_binary(bytes: &[u8], model: &mut super::bin::Model, factor: uint) -> Vec<u8> {
fn encode_binary(bytes: &[u8], model: &mut super::bin::Model) -> Vec<u8> {
let mut encoder = super::Encoder::new(MemWriter::new());
for &byte in bytes.iter() {
for i in range(0,8) {
let bit = (byte & (1<<i)) != 0;
encoder.encode(bit, model).unwrap();
model.update(bit, factor);
model.update(bit);
}
}
let (writer, err) = encoder.finish();
err.unwrap();
writer.unwrap()
}

fn roundtrip_binary(bytes: &[u8], factor: uint) {
let mut bm = super::bin::Model::new_flat(super::RANGE_DEFAULT_THRESHOLD >> 3);
let output = encode_binary(bytes, &mut bm, factor);
fn roundtrip_binary(bytes: &[u8], factor: u32) {
let mut bm = super::bin::Model::new_flat(super::RANGE_DEFAULT_THRESHOLD >> 3, factor);
let output = encode_binary(bytes, &mut bm);
bm.reset_flat();
let mut decoder = super::Decoder::new(BufReader::new(output.as_slice()));
for &byte in bytes.iter() {
let mut value = 0u8;
for i in range(0,8) {
let bit = decoder.decode(&bm).unwrap();
bm.update(bit, factor);
bm.update(bit);
value += (bit as u8)<<i;
}
assert_eq!(value, byte);
Expand Down Expand Up @@ -87,15 +87,13 @@ fn roundtrip_term(bytes1: &[u8], bytes2: &[u8]) {

fn roundtrip_proxy(bytes: &[u8]) {
// prepare data
let factor0 = 3;
let factor1 = 5;
let update0 = 10;
let update1 = 5;
let threshold = super::RANGE_DEFAULT_THRESHOLD >> 3;
let mut t0 = super::table::Model::new_flat(16, threshold);
let mut t1 = super::table::Model::new_flat(16, threshold);
let mut b0 = super::bin::Model::new_flat(threshold);
let mut b1 = super::bin::Model::new_flat(threshold);
let mut b0 = super::bin::Model::new_flat(threshold, 3);
let mut b1 = super::bin::Model::new_flat(threshold, 5);
// encode (high 4 bits with the proxy table, low 4 bits with the proxy binary)
let mut encoder = super::Encoder::new(MemWriter::new());
for &byte in bytes.iter() {
Expand All @@ -112,8 +110,8 @@ fn roundtrip_proxy(bytes: &[u8]) {
let proxy = super::bin::SumProxy::new(1, &b0, 1, &b1, 1);
encoder.encode(bit, &proxy).unwrap();
}
b0.update(bit, factor0);
b1.update(bit, factor1);
b0.update(bit);
b1.update(bit);
}
}
let (writer, err) = encoder.finish();
Expand All @@ -139,13 +137,48 @@ fn roundtrip_proxy(bytes: &[u8]) {
decoder.decode(&proxy).unwrap()
};
value += (bit as u8)<<i;
b0.update(bit, factor0);
b1.update(bit, factor1);
b0.update(bit);
b1.update(bit);
}
assert_eq!(value, byte);
}
}

/// Encode `bytes` bit by bit (least significant bit first) with a single
/// adaptive `Bit` refined through a `Gate`, then decode the stream with
/// freshly reset models and check that every byte comes back unchanged.
fn roundtrip_apm(bytes: &[u8]) {
    // encoding stage
    let mut model = super::apm::Bit::new_equal();
    let mut gate = super::apm::Gate::new();
    let mut encoder = super::Encoder::new(MemWriter::new());
    for &byte in bytes.iter() {
        for shift in range(0,8) {
            let value = ((byte>>shift) & 1) != 0;
            let (refined, coords) = gate.pass(&model);
            encoder.encode(value, &refined).unwrap();
            model.update(value, 10, 0);
            gate.update(value, coords, 10, 0);
        }
    }
    let (writer, err) = encoder.finish();
    err.unwrap();
    let output = writer.unwrap();
    // decoding stage: start from the same initial state as the encoder
    model = super::apm::Bit::new_equal();
    gate = super::apm::Gate::new();
    let mut decoder = super::Decoder::new(BufReader::new(output.as_slice()));
    for &byte in bytes.iter() {
        let mut restored = 0u8;
        for shift in range(0,8) {
            let (refined, coords) = gate.pass(&model);
            let value = decoder.decode(&refined).unwrap();
            if value {
                restored += 1<<shift;
            }
            // mirror the encoder's model updates exactly
            model.update(value, 10, 0);
            gate.update(value, coords, 10, 0);
        }
        assert_eq!(restored, byte);
    }
}


#[test]
fn roundtrips() {
roundtrip(bytes!("abracadabra"));
Expand All @@ -170,6 +203,11 @@ fn roundtrips_proxy() {
roundtrip_proxy(TEXT_INPUT);
}

// Round-trip a short text through the coder driven by the APM bit and gate
#[test]
fn roundtrips_apm() {
    roundtrip_apm(bytes!("abracadabra"));
}

#[bench]
fn compress_speed(bh: &mut Bencher) {
let mut storage = Vec::from_elem(TEXT_INPUT.len(), 0u8);
Expand Down