# pack UTF-8 shift-based DFA into 32-bit integers
#
# based on work by Dougall Johnson:
# https://gist.github.com/dougallj/166e326de6ad4cf2c94be97a204c025f

from z3 import *

####### spec for 32-bit shift-based DFA

# Number of DFA states; every row of `transitions` has one entry per state.
num_states = 9

# State indices, named after the column abbreviations used for the table.
# ERR is index 0 and BGN is index 1; the remaining names follow in order.
ERR, BGN, CS1, CS2, CS3, P3A, P3B, P4A, P4B = range(num_states)

# Unique state->state transitions for all character classes.
# One row per character class (named in the trailing comment), one column per
# current state; each entry is the state reached from that column's state.
transitions = [
    # current state:
    # ERR  BGN  CS1  CS2  CS3  P3A  P3B  P4A  P4B
    (ERR, ERR, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # ILL
    (ERR, BGN, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # ASC
    (ERR, CS1, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # L2A
    (ERR, CS2, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # L3B
    (ERR, CS3, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # L4B
    (ERR, P3A, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # L3A
    (ERR, P3B, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # L3C
    (ERR, P4A, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # L4A
    (ERR, P4B, ERR, ERR, ERR, ERR, ERR, ERR, ERR),  # L4C
    (ERR, ERR, BGN, CS1, CS2, ERR, CS1, ERR, CS2),  # CR1
    (ERR, ERR, BGN, CS1, CS2, ERR, CS1, CS2, ERR),  # CR2
    (ERR, ERR, BGN, CS1, CS2, CS1, ERR, CS2, ERR),  # CR3
]

s = Solver()

# One 32-bit unknown per DFA state: the bit position ("offset") that stands
# for that state inside every packed transition word.
offsets = [BitVec('o%d' % i, 32) for i in range(num_states)]
# One 32-bit unknown per row of `transitions`: the packed word for that class.
values  = [BitVec('v%d' % i, 32) for i in range(len(transitions))]

# Every offset must be a usable shift amount for a 32-bit word.  `<` on z3
# bit-vectors is a signed comparison, which would also accept offsets with the
# top bit set, so the bound is expressed with the unsigned ULT.
for off in offsets:
	s.add(ULT(off, 32))

# No two states may share an offset.
s.add(Distinct(*offsets))

# Core shift-DFA property: shifting a class's word right by the current
# state's offset and keeping the low five bits must yield the offset of the
# state that class leads to.
for value, targets in zip(values, transitions):
	for source, target in enumerate(targets):
		s.add(((value >> offsets[source]) & 31) == offsets[target])


####### not strictly necessary, but keep things consistent

# give each state's offset a readable name (same order as the state indices)
(o_err, o_bgn, o_cs1, o_cs2, o_cs3,
 o_p3a, o_p3b, o_p4a, o_p4b) = offsets

# pin the error state to offset zero
s.add(o_err == 0)

# avoid sign extension: every word after the first must be strictly positive
# as a signed number, i.e. its top bit is clear
for word in values[1:]:
	s.add(word > 0)

# force transitions to match expressions based on the states (i.e. keep "dark" bits out)
# - the first word is all zeroes
s.add(values[0] == 0)
# - words 1..8 may only have bits inside the 5-bit field at offset o_bgn
begin_field = 31 << o_bgn
for lead_word in values[1:9]:
	s.add(lead_word & begin_field == lead_word)
# - words 9..11 are exactly the OR of their "target offset << source offset"
#   terms; the first three terms are common to all three words
shared_terms = (o_bgn << o_cs1) | (o_cs1 << o_cs2) | (o_cs2 << o_cs3)
s.add(values[9]  == shared_terms | (o_cs1 << o_p3b) | (o_cs2 << o_p4b))
s.add(values[10] == shared_terms | (o_cs1 << o_p3b) | (o_cs2 << o_p4a))
s.add(values[11] == shared_terms | (o_cs1 << o_p3a) | (o_cs2 << o_p4a))

# use increasing order where possible to make it look nicer:
# offsets 4 through 8 must be strictly ascending
for lower, upper in zip(offsets[4:8], offsets[5:9]):
	s.add(lower < upper)


####### run the solver

# Solve once and keep the verdict so a failure can be reported instead of
# silently printing nothing.
result = s.check()
if result == sat:
	# Fetch the model a single time rather than once per variable.
	model = s.model()

	# model_completion=True makes z3 supply a concrete value even for a
	# variable it left unassigned, so .as_long() is always available.
	# From here on `offsets` and `values` hold plain ints, not z3 terms.
	offsets = [model.eval(o, model_completion=True).as_long() for o in offsets]
	print('offsets:', offsets)

	values = [model.eval(v, model_completion=True).as_long() for v in values]
	print('transitions:')
	for v in values:
		# 32-digit binary next to the hex form of each packed word
		print(format(v, '032b'), hex(v))
else:
	# unsat or unknown: write the verdict to stderr and exit with status 1
	raise SystemExit('no packing found: solver returned %s' % result)
