I had a blast working through problem 24 of Advent of Code 2021. The size of the input space (9^14 possibilities) forced us to give up on brute force and try many other things. Reverse engineering a program in a custom language was fun, using Z3 is always cool, and many people were very creative.
There are many ways to think about this problem, and after being stuck for a while, I spent some time on the subreddit understanding the different approaches and rewriting different solutions myself.
Here is a breakdown of some approaches that resonated with me; they are conceptually very different.
I’ve studied ad hoc solutions (solutions tailor-made for this particular ALU program) and generic solutions (solutions that could work for any ALU program).
Advent of Code 2021 Day 24 is described here. Here is a summary:
You are given an ALU program that consists of 252 instructions and reads 14 input values, each between 1 and 9. The ALU has four integer variables, w, x, y, and z, which all start with the value 0. It supports six instructions: inp a, add a b, mul a b, div a b, mod a b, and eql a b.
In all of these instructions, a and b are placeholders; a will always be the variable where the result of the operation is stored (one of w, x, y, or z), while b can be either a variable or a number. Numbers can be positive or negative, but will always be integers.
What is the highest 14-digit input that leaves register z equal to 0 at the end of the program?
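To make the instruction set concrete, here is a minimal interpreter sketch. The function name `run_alu` and its return format are my own choices for illustration, not part of the puzzle; it assumes `div` truncates toward zero, as the puzzle statement specifies.

```python
def run_alu(program: str, inputs):
    """Run an ALU program on the given inputs and return the final registers."""
    regs = {"w": 0, "x": 0, "y": 0, "z": 0}
    pending = iter(inputs)
    for line in program.strip().splitlines():
        op, a, *rest = line.split()
        if op == "inp":
            regs[a] = next(pending)
            continue
        # b is either a register name or an integer literal
        b = regs[rest[0]] if rest[0] in regs else int(rest[0])
        if op == "add":
            regs[a] += b
        elif op == "mul":
            regs[a] *= b
        elif op == "div":
            # truncate toward zero (Python's // floors instead)
            quotient = abs(regs[a]) // abs(b)
            regs[a] = quotient if (regs[a] < 0) == (b < 0) else -quotient
        elif op == "mod":
            regs[a] %= b
        elif op == "eql":
            regs[a] = int(regs[a] == b)
        else:
            raise ValueError(f"unknown instruction: {line}")
    return regs
```

Running this interpreter on all 9^14 candidate inputs is exactly the brute force that the rest of the article tries to avoid.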
In this first part, we’ll take a look at the solutions where people reverse-engineered the ALU program and used some of its properties to solve the problem.
The ALU program we need to study consists of the same 18-instruction block repeated 14 times. Only 3 values vary from one block to the next:
inp w
mul x 0
add x z
mod x 26
div z 1 -> this value can change
add x 11 -> this value can change
eql x w
eql x 0
mul y 0
add y 25
mul y x
add y 1
mul z y
mul y 0
add y w
add y 16 -> this value can change
mul y x
add z y
For the remainder of this article, we’ll refer to those three parameters, in order of appearance, as:
zdiv
xcheck
yadd
We can first write a piece of code in order to extract these elements.
# common24.py
import re

def extract_parameters(program):
    # The three capture groups are zdiv, xcheck and yadd, in that order.
    # The pattern lines must stay unindented so they match the raw input.
    repeated_program = r"""inp w
mul x 0
add x z
mod x 26
div z (.*)
add x (.*)
eql x w
eql x 0
mul y 0
add y 25
mul y x
add y 1
mul z y
mul y 0
add y w
add y (.*)
mul y x
add z y"""
    div_check_add = re.findall(repeated_program, program)
    assert len(div_check_add) == 14, len(div_check_add)
    return [list(map(int, dca)) for dca in div_check_add]
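As an alternative to the regular expression, the same three values can be read by position, since they always sit on lines 5, 6 and 16 of each 18-line block. This is only a sketch: `extract_parameters_by_position`, `BLOCK_TEMPLATE` and the two-block toy program are names and data I made up for illustration.

```python
BLOCK_TEMPLATE = """inp w
mul x 0
add x z
mod x 26
div z {zdiv}
add x {xcheck}
eql x w
eql x 0
mul y 0
add y 25
mul y x
add y 1
mul z y
mul y 0
add y w
add y {yadd}
mul y x
add z y"""


def extract_parameters_by_position(program: str):
    """Return one (zdiv, xcheck, yadd) tuple per 18-line block."""
    lines = program.strip().splitlines()
    assert len(lines) % 18 == 0, len(lines)
    params = []
    for start in range(0, len(lines), 18):
        block = lines[start:start + 18]
        zdiv = int(block[4].split()[2])     # "div z <zdiv>"
        xcheck = int(block[5].split()[2])   # "add x <xcheck>"
        yadd = int(block[15].split()[2])    # "add y <yadd>"
        params.append((zdiv, xcheck, yadd))
    return params


# A made-up two-block program (the real input has 14 blocks)
toy_program = "\n".join(
    BLOCK_TEMPLATE.format(zdiv=d, xcheck=c, yadd=a)
    for d, c, a in [(1, 11, 16), (26, -10, 5)]
)
print(extract_parameters_by_position(toy_program))
```

The regex version has the advantage of failing loudly if the block structure is not what we expect; the positional version trusts the layout.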
Now, what is this 18-line program doing? If we reverse-engineer the 18 instructions, we can rewrite them as a piece of Python code:
# based on
# https://www.reddit.com/r/adventofcode/comments/rnejv5/comment/hpuaphy/?utm_source=share&utm_medium=web2x&context=3
def subroutine(z, w, zdiv, xcheck, yadd):
    """z is the running register value, w is our input"""
    x = z % 26  # store the last element of z in x
    z //= zdiv  # if zdiv == 26, pop the last element of z
    if x != w - xcheck:
        z = 26 * z + w + yadd  # push w + yadd to z
    return z
This subroutine is called once per input, with the corresponding values of zdiv, xcheck and yadd. The resulting z is then fed to the next subroutine. Only the z register carries state between blocks: w is overwritten by each input, and x and y are reset to 0 before they are used.
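Chaining the blocks can be sketched as follows. The names `run_block` and `final_z` are mine, and `toy_params` is a made-up two-block example rather than real puzzle input.

```python
def run_block(z, w, zdiv, xcheck, yadd):
    """One 18-instruction block: returns the new z for input digit w."""
    x = z % 26
    z //= zdiv
    if x + xcheck != w:
        z = 26 * z + w + yadd
    return z


def final_z(digits, params):
    """Feed the digits through the blocks in order, starting from z = 0."""
    z = 0
    for w, (zdiv, xcheck, yadd) in zip(digits, params):
        z = run_block(z, w, zdiv, xcheck, yadd)
    return z


# Hypothetical parameters: one block that pushes, one that can pop
toy_params = [(1, 11, 16), (26, -10, 5)]
print(final_z([3, 9], toy_params))  # the second digit cancels the first
print(final_z([1, 1], toy_params))  # no cancellation, z stays non-zero
```

With these toy parameters, z returns to 0 only when the second digit equals the first plus 6, which is the kind of relation the solutions below exploit.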
Given this knowledge, there are now many ways to solve this problem.
We can solve the problem knowing only the subroutine: given the output, what are the inputs that can produce it? We start from the end, where we know that z must be 0. We extract the list of possible inputs that make z reach zero in the last block, along with the corresponding starting values of z. Then we repeat the process with the previous block, using our newly found list of z values, and so on until we reach the first block. For part 1 we keep the highest inputs, for part 2 the lowest.
from common24 import extract_parameters
import itertools as it

# https://gist.github.com/jkseppan/1e36172ad4f924a8f86a920e4b1dc1b1
def backward(xcheck, yadd, zdiv, z_final, w):
    """Returns the possible values of z before a single block
    if the final value of z is z_final and input is w
    """
    zs = []
    # case 1: the block pushed w + yadd onto z
    x = z_final - w - yadd
    if x % 26 == 0:
        zs.append(x // 26 * zdiv)
    # case 2: the check matched, so nothing was pushed
    if 0 <= w - xcheck < 26:
        z0 = z_final * zdiv
        zs.append(w - xcheck + z0)
    return zs

def solve(part, div_check_add):
    zs = {0}
    result = {}
    # later values of w overwrite earlier ones in result,
    # so the preferred digit must come last
    if part == 1:
        ws = list(range(1, 10))
    else:
        ws = list(range(9, 0, -1))
    for zdiv, xcheck, yadd in div_check_add[::-1]:
        newzs = set()
        for w, z in it.product(ws, zs):
            z0s = backward(xcheck, yadd, zdiv, z, w)
            for z0 in z0s:
                newzs.add(z0)
                result[z0] = (w,) + result.get(z, ())
        zs = newzs
    return ''.join(map(str, result[0]))

input = open("input/24_2021.txt").read()
div_check_add = extract_parameters(input)
print(solve(1, div_check_add))
print(solve(2, div_check_add))
Here is another solution that uses this approach:
Some people have visualized z as a base-26 number that behaves like a stack, where we push and pop values depending on some conditions. With this approach, we start from the first input value, and adjust it on the fly.
[the subroutine] can be re-written using an if-block, which eliminates the x-register:
if z % 26 + chk != inp:
    z //= div
    z *= 26
    z += inp + add
else:
    z //= div
We note that "div" can only be one of two value: either 1 or 26. This
leads us to observe that all computations are manipulations of digits
of the z-register written in base 26. So it's natural to define
"div = 26**shf", so that "shf" will be either 0 or 1. We can use
binary operators to denote operations in base 26 as follows:
z * 26 = z << 1
z // 26 = z >> 1
z % 26 = z & 1
With this we can write the program as follows:
if z & 1 + chk != inp:
    z = z >> shf
    z = z << 1
    z = z + (inp + add)
else:
    z = z >> shf
We can also write the bitwise operations as follows:
z & 1 = z.last_bit
z >> 1 = z.pop()
(z << 1) & a = z.push(a)
((z >> 1) << 1) & a = z.pop_push(a)
where pop/push refer to that bit stack of z in base 26 with the last
bit on top. Therefore, z.pop() removes the last bit, z.push(a) appends
the bit "a", and z.pop_push(a) replaces the last bit by "a".
Given that shf can only be 0 or 1 we get the following two cases:
if shf == 0:
    if z.last_bit + chk != inp:
        z.push(inp + add)
elif shf == 1:
    if z.last_bit + chk != inp:
        z.pop_push(inp + add)
    else:
        z.pop()
According to the puzzle input (our input) in all cases where shf == 0
it's true that chk > 9. Given that 1 <= inp <= 9 the check
(if z.last_bit + chk != inp) will therefore always be true. This gives:
if shf == 0:
    z.push(inp + add)
elif shf == 1:
    if z.last_bit + chk == inp:
        z.pop()
    else:
        z.pop_push(inp + add)
We can summarize in words. View z as a stack of bits in base 26. Start
with an empty stack. Whenever shf == 0 (div == 1) push (inp + add) on
the stack. If, however, shf == 1, consider the last bit on the stack.
If it's equal to (inp - chk), then remove it, otherwise replace it by
(inp + add).
(Let’s be clear: it took me a looot of time to get a good grasp of this approach, it’s OK if it doesn’t make sense right away).
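What helped me was to look at z as an actual list of base-26 digits. The sketch below is just for illustration: `push`, `pop` and `base26_stack` are names I picked, and it assumes every pushed value is between 1 and 25 (a pushed 0 at the bottom of the stack would be invisible).

```python
def push(z, value):
    """Append a base-26 digit to z."""
    return z * 26 + value


def pop(z):
    """Drop the last base-26 digit of z."""
    return z // 26


def base26_stack(z):
    """Return the base-26 digits of z, bottom of the stack first."""
    digits = []
    while z > 0:
        digits.append(z % 26)
        z //= 26
    return digits[::-1]


z = push(push(0, 17), 5)   # push 17, then 5
print(z, base26_stack(z))  # the top of the stack is z % 26
print(base26_stack(pop(z)))
```

Seen this way, z can only be 0 at the end if the stack is empty, so every push must be cancelled by a matching pop.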
We can then solve the problem very quickly, by searching for the correct input value in each of the 14 blocks. Occasionally, we’ll have to go back and correct a value chosen earlier. It leads to a very fast solution (source):
from common24 import extract_parameters

def solve(inp, div_check_add) -> str:
    """Solve the problem by regarding z as a base 26 number"""
    zstack = []
    for i, oc in enumerate(div_check_add):
        zdiv, xcheck, yadd = oc
        if zdiv == 1:
            zstack.append((i, yadd))
        elif zdiv == 26:
            j, yadd = zstack.pop()
            inp[i] = inp[j] + yadd + xcheck
            if inp[i] > 9:
                inp[j] = inp[j] - (inp[i] - 9)
                inp[i] = 9
            if inp[i] < 1:
                inp[j] = inp[j] + (1 - inp[i])
                inp[i] = 1
        else:
            raise (ValueError(f"unsupported div value: {zdiv}"))
    assert (len(zstack) == 0), len(zstack)  # the stack must be 0 in order for z to reach 0 at the end
    return "".join(map(str, inp))

def part1(div_check_add):
    "biggest accepted number"
    return solve(list([9] * 14), div_check_add)

def part2(div_check_add):
    """smallest accepted number"""
    return solve(list([1] * 14), div_check_add)

input = open("input/24_2021.txt").read()
div_check_add = extract_parameters(input)
print(part1(div_check_add))
print(part2(div_check_add))
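The solver above relies on the fact that each pop block is tied to the push block it cancels: for z to end at 0, digit i must equal digit j plus (yadd of block j + xcheck of block i). Here is a small sketch that only lists those relations; `pair_constraints` is a name I made up and the parameters are toy values, not real puzzle input.

```python
def pair_constraints(params):
    """Return (i, j, delta) triples meaning digit[i] == digit[j] + delta."""
    stack, constraints = [], []
    for i, (zdiv, xcheck, yadd) in enumerate(params):
        if zdiv == 1:
            # this block always pushes digit[i] + yadd
            stack.append((i, yadd))
        else:
            # this block must pop what block j pushed
            j, pushed_add = stack.pop()
            constraints.append((i, j, pushed_add + xcheck))
    return constraints


# Hypothetical four-block program with two nested push/pop pairs
toy_params = [(1, 12, 4), (1, 13, 8), (26, -3, 1), (26, -9, 2)]
for i, j, delta in pair_constraints(toy_params):
    print(f"digit[{i}] == digit[{j}] + ({delta})")
```

Once the relations are written down like this, picking the largest or smallest valid digits for each pair can even be done by hand.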
Z3 is a theorem prover. If you’ve never worked with formal methods before, it’s quite surprising: you don’t write an algorithm that solves the problem, you declare the constraints of your problem and Z3 finds matching solutions. I wrote about it before, and here is a great tutorial on the matter.
Z3 is quite popular in security CTFs where people are tasked with reversing a hash, that is to say finding an input that produces a specific output with a known (but complicated) hashing algorithm. That’s exactly what we have to do here; we can use Z3 to solve the simplified subroutines for us.
from z3 import *
from common24 import extract_parameters

def solve(div_check_add: list, part: int) -> int:
    solver = Optimize()
    z = 0  # this is our running z, which has to be zero at the start and end
    # We have 14 inputs, they all are integers between 1 and 9 included
    ws = [Int(f'w{i}') for i in range(14)]
    for i in range(14):
        solver.add(And(ws[i] >= 1, ws[i] <= 9))
    # The value where we concatenate our input digits
    digits_base_10 = Int("digits_base_10")
    solver.add(digits_base_10 == sum((10 ** i) * d for i, d in enumerate(ws[::-1])))
    # We implement the subroutine as a list of constraints, one for each of the 14 blocks:
    for (i, [div, check, add]) in enumerate(div_check_add):
        z = If(z % 26 + check == ws[i], z / div, z / div * 26 + ws[i] + add)
    # The final z value must be zero
    solver.add(z == 0)
    if part == 1:
        solver.maximize(digits_base_10)
    else:
        solver.minimize(digits_base_10)
    assert (solver.check() == sat)  # the solver must find a solution
    return solver.model().eval(digits_base_10)

input = open("input/24_2021.txt").read()
div_check_add = extract_parameters(input)
print(solve(div_check_add, 1))
print(solve(div_check_add, 2))
I’m keeping this here because it’s clever. Mebeim had issues writing a 100% Z3 solution, so he used GCC to simplify the equations:
- Rewrite the [input program by hand in C](https://github.com/mebeim/aoc/blob/master/2021/misc/day24/program.c) (with a bunch of macros) and let GCC compile it with maximum optimizations.
- Decompile the generated binary in IDA Pro (or [Ghidra](https://ghidra-sre.org/) if you want), which should give [a pretty good decompiled source with simplified equations](https://github.com/mebeim/aoc/blob/master/2021/misc/day24/decompiled.c) (thanks GCC!).
- Copy paste [the equations into a new Z3 Python script](https://github.com/mebeim/aoc/blob/master/2021/original_solutions/day24.py) and solve for the maximum/minimum using the Z3 Optimizer solver, which this time can manage to work in a decent runtime with the simplified equations (~30s).
In the second part, we study approaches that could work for any ALU program, even one without the properties we relied on in the first part (repeated code, x and y being useless, and so on).
The brute-force way means going through all the possibilities. That is O(n) in the number of candidates, but when n is large and each check is slow, it takes a lot of time. The problem involves 9^14 (about 2.3 × 10^13) possibilities, so a naive brute-force approach could mean at least a few days of computation.
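To get a feel for the numbers, here is a back-of-the-envelope sketch. The throughput figures are pure assumptions on my part, not measurements of any real solver.

```python
CANDIDATES = 9 ** 14          # 14 digits, each between 1 and 9
SECONDS_PER_DAY = 86_400


def brute_force_days(candidates_per_second):
    """Days needed to test every candidate at the given (assumed) rate."""
    return CANDIDATES / candidates_per_second / SECONDS_PER_DAY


# Assumed rates: a slow interpreter vs. heavily optimized native code
for rate in (1e6, 1e8):
    print(f"{rate:.0e} candidates/s -> {brute_force_days(rate):.1f} days")
```

Even at a hypothetical hundred million candidates per second, a full scan takes days, which is why the approaches below try to prune or deduplicate states.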
Matt Keeter wrote a detailed article about how he used code generation, state deduplication and multithreading in order to speed it up.
DFS on its own was too slow, so memoization helped avoid recomputing similar steps in other subroutines. This is not totally generic since it takes advantage of the fact that the program consists of 14 repeated blocks… but that’s clever anyway.
I noticed the input code had 14 very similar chunks that each took a single input, and the only thing that mattered about the state between chunks was z and input w registers -- x and y are zeroed out each time. So I basically just did a DFS with depth 14 and memoization on the branches, trying to find z=0 at the final step. And I added a cache for the execution to speed up part 2. Part 2 took 8 minutes runtime.
https://github.com/WilliamLP/AdventOfCode/blob/master/2021/day24.py
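Here is a minimal sketch of the idea, not WilliamLP's actual code. To keep it short it runs on the simplified (zdiv, xcheck, yadd) blocks rather than on the raw instructions, and it adds a pruning bound of my own (z can shrink by at most a factor of 26 per remaining pop block). `find_model_number` and the toy parameters are hypothetical.

```python
from functools import lru_cache


def find_model_number(params, digits=range(9, 0, -1)):
    """Return the first accepted digit string in the given digit order, or None."""
    params = tuple(params)
    # pops_left[i] = number of blocks from i onward that divide z by 26
    pops_left = [0] * (len(params) + 1)
    for i in range(len(params) - 1, -1, -1):
        pops_left[i] = pops_left[i + 1] + (params[i][0] == 26)

    @lru_cache(maxsize=None)
    def search(i, z):
        if i == len(params):
            return "" if z == 0 else None
        # z is too large to ever come back to 0: prune this branch
        if z >= 26 ** pops_left[i]:
            return None
        zdiv, xcheck, yadd = params[i]
        for w in digits:
            next_z = z // zdiv
            if z % 26 + xcheck != w:
                next_z = 26 * next_z + w + yadd
            rest = search(i + 1, next_z)
            if rest is not None:
                return str(w) + rest
        return None

    return search(0, 0)


toy_params = [(1, 11, 16), (26, -10, 5)]
print(find_model_number(toy_params))               # largest accepted input
print(find_model_number(toy_params, range(1, 10)))  # smallest accepted input
```

The cache is keyed on (block index, z) only, which works because z is the only state carried from one block to the next.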
The use of Z3 for this kind of reversing problem is quite popular in CTF contests. I was surprised not to see it more often. With it, we can entirely bruteforce the problem by simply describing the problem and adding constraints for the result of each instruction.
This is very similar to the previous Z3 approach, except that we implement all the instructions instead of the simplified routine. Interestingly, both Z3 solutions run in about the same time.
import z3
# parse_code is not shown in this article; it is assumed to return
# one [opcode, operand, ...] list of strings per instruction
from common24 import parse_code

# https://www.reddit.com/r/adventofcode/comments/rnejv5/2021_day_24_solutions/hpshymr/
alu_program = parse_code(open('input/24_2021.txt', 'r').read())
solver = z3.Optimize()
# Two constants we’ll need later on
zero, one = z3.BitVecVal(0, 64), z3.BitVecVal(1, 64)
# our input digits
digits = [z3.BitVec(f'd_{i}', 64) for i in range(14)]
for d in digits:
    solver.add(z3.And(1 <= d, d <= 9))
digit_input = iter(digits)
# the base 10 result where we concatenate the values
digits_base_10 = z3.BitVec("digits_base_10", 64)
solver.add(digits_base_10 == sum((10 ** i) * d for i, d in enumerate(digits[::-1])))
# Now we implement the entire ALU program through Z3
registers = {r: zero for r in 'xyzw'}
for i, instruction in enumerate(alu_program):
    # for every instruction, we create an intermediate value whose constraint is to
    # hold the result of this instruction
    if instruction[0] == 'inp':
        registers[instruction[1]] = next(digit_input)
    else:
        register, operand = instruction[1:]
        operand = registers[operand] if operand in registers else int(operand)
        instruction_i = z3.BitVec(f'instruction_{i}', 64)
        if instruction[0] == 'add':
            solver.add(instruction_i == registers[register] + operand)
        elif instruction[0] == 'mul':
            solver.add(instruction_i == registers[register] * operand)
        elif instruction[0] == 'mod':
            solver.add(registers[register] >= 0)
            solver.add(operand > 0)
            solver.add(instruction_i == registers[register] % operand)
        elif instruction[0] == 'div':
            solver.add(operand != 0)
            solver.add(instruction_i == registers[register] / operand)
        elif instruction[0] == 'eql':
            solver.add(instruction_i == z3.If(registers[register] == operand, one, zero))
        else:
            assert False
        registers[register] = instruction_i
solver.add(registers['z'] == 0)
for f in (solver.maximize, solver.minimize):
    solver.push()
    f(digits_base_10)
    assert(solver.check() == z3.sat)
    print(solver.model().eval(digits_base_10))
    solver.pop()
Some warriors decided to implement a homemade symbolic calculator, because why not?