-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patharithmetic_coder.py
More file actions
97 lines (85 loc) · 2.74 KB
/
arithmetic_coder.py
File metadata and controls
97 lines (85 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from fractions import Fraction
def uniform(syms):
def pdist(seq):
return {s: Fraction(1, len(syms)) for s in syms}
return pdist
def cdfs(distribution):
cdf_lo = {}
cdf_hi = {}
total_probability = 0
for sym, prob in distribution.items():
cdf_lo[sym] = total_probability # we could also use the last symbol as the lower bound but this is more explicit
total_probability += prob
cdf_hi[sym] = total_probability
return cdf_lo, cdf_hi
def encode(sequence, pdist):
a = Fraction(0, 1)
b = Fraction(1, 1)
bits = []
for i, x in enumerate(sequence):
distribution = pdist(sequence[:i])
cdf_lo, cdf_hi = cdfs(distribution)
rang = b - a
b = a + rang * cdf_hi[x]
a += rang * cdf_lo[x]
# if our interval is sufficiently constrained, emit bits
while True:
if a < b < Fraction(1, 2):
bits.append(0)
a *= 2
b *= 2
elif Fraction(1, 2) < a < b:
bits.append(1)
a *= 2
b *= 2
a -= 1
b -= 1
else:
break
# it may still be the case that a and b are in [0,1], i.e. we have not emitted enough bits
# to constrain it in the decoder
while a > 0 or b < 1:
c = (a + b) / 2 # slightly suboptimal: try and write out midpoint
if c < Fraction(1, 2):
bits.append(0)
a *= 2
b *= 2
else:
bits.append(1)
a *= 2
a -= 1
b *= 2
b -= 1
return bits
def decode(bits, pdist):
a = Fraction(0, 1)
b = Fraction(1, 1)
scale = Fraction(1, 2)
sequence = []
while bits:
distribution = pdist(sequence)
cdf_lo, cdf_hi = cdfs(distribution)
for (lo, hi), sym in zip(zip(cdf_lo.values(), cdf_hi.values()), cdf_lo.keys()):
if lo <= a < b <= hi: # if message interval is subset of symbol interval, emit symbol
sequence.append(sym)
# expand working interval and adjust scale up
# (effectively, consume symbol in weird base)
range = hi - lo
a = (a - lo) / range
b = (b - lo) / range
scale /= range
break
else:
# consume a bit
bit = bits.pop(0)
a += bit * scale
b = a + scale
scale /= 2
return sequence
dist = lambda seq: {"a": Fraction(98, 100), "b": Fraction(1, 100), "c": Fraction(1, 100)}
input = "abbabaccabbabbbbbba"
bits = encode(input, dist)
print(bits)
result = "".join(decode(bits, dist))
print(input, result)
assert input == result, "oh no"