Challenge Description
SusCipher 400 pts (6 solves)
authored by rbtree
I made SusCipher, which is a vulnerable block cipher so everyone can break it!
Please, try it and find a key.
nc suscipher.chal.ctf.acsc.asia 13579
nc suscipher-2.chal.ctf.acsc.asia 13579 (Backup)
Hint: Differential cryptanalysis is useful.
SusCipher.tar.gz
Source files
Source Analysis
#!/usr/bin/env python3
import hashlib
import os
import signal


class SusCipher:
    S = [
        43,  8, 57, 53, 48, 39, 15, 61,
         7, 44, 33,  9, 19, 41,  3, 14,
        42, 51,  6,  2, 49, 28, 55, 31,
         0,  4, 30,  1, 59, 50, 35, 47,
        25, 16, 37, 27, 10, 54, 26, 58,
        62, 13, 18, 22, 21, 24, 12, 20,
        29, 38, 23, 32, 60, 34,  5, 11,
        45, 63, 40, 46, 52, 36, 17, 56
    ]

    P = [
        21,  8, 23,  6,  7, 15,
        22, 13, 19, 16, 25, 28,
        31, 32, 34, 36,  3, 39,
        29, 26, 24,  1, 43, 35,
        45, 12, 47, 17, 14, 11,
        27, 37, 41, 38, 40, 20,
         2,  0,  5,  4, 42, 18,
        44, 30, 46, 33,  9, 10
    ]

    ROUND = 3
    BLOCK_NUM = 8
    MASK = (1 << (6 * BLOCK_NUM)) - 1

    @classmethod
    def _divide(cls, v: int) -> list[int]:
        l: list[int] = []
        for _ in range(cls.BLOCK_NUM):
            l.append(v & 0b111111)
            v >>= 6
        return l[::-1]

    @staticmethod
    def _combine(block: list[int]) -> int:
        res = 0
        for v in block:
            res <<= 6
            res |= v
        return res

    @classmethod
    def _sub(cls, block: list[int]) -> list[int]:
        return [cls.S[v] for v in block]

    @classmethod
    def _perm(cls, block: list[int]) -> list[int]:
        bits = ""
        for b in block:
            bits += f"{b:06b}"

        buf = ["_" for _ in range(6 * cls.BLOCK_NUM)]
        for i in range(6 * cls.BLOCK_NUM):
            buf[cls.P[i]] = bits[i]

        permd = "".join(buf)
        return [int(permd[i : i + 6], 2) for i in range(0, 6 * cls.BLOCK_NUM, 6)]

    @staticmethod
    def _xor(a: list[int], b: list[int]) -> list[int]:
        return [x ^ y for x, y in zip(a, b)]

    def __init__(self, key: int):
        assert 0 <= key <= self.MASK

        keys = [key]
        for _ in range(self.ROUND):
            v = hashlib.sha256(str(keys[-1]).encode()).digest()
            v = int.from_bytes(v, "big") & self.MASK
            keys.append(v)

        self.subkeys = [self._divide(k) for k in keys]

    def encrypt(self, inp: int) -> int:
        block = self._divide(inp)

        block = self._xor(block, self.subkeys[0])
        for r in range(self.ROUND):
            block = self._sub(block)
            block = self._perm(block)
            block = self._xor(block, self.subkeys[r + 1])

        return self._combine(block)

    # TODO: Implement decryption
    def decrypt(self, inp: int) -> int:
        raise NotImplementedError()


def handler(_signum, _frame):
    print("Time out!")
    exit(0)


def main():
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(300)
    key = int.from_bytes(os.urandom(6), "big")

    cipher = SusCipher(key)

    while True:
        inp = input("> ")

        try:
            l = [int(v.strip()) for v in inp.split(",")]
        except ValueError:
            print("Wrong input!")
            exit(0)

        if len(l) > 0x100:
            print("Long input!")
            exit(0)

        if len(l) == 1 and l[0] == key:
            with open('flag', 'r') as f:
                print(f.read())

        print(", ".join(str(cipher.encrypt(v)) for v in l))


if __name__ == "__main__":
    main()
Let’s take a look at the relevant parts
main
In an infinite loop, the server asks for an input, which is a string of numbers separated by commas.
As long as we send at most 0x100 (256) numbers at a time, we can get as many encryptions as we like (within the 300-second alarm).
If we enter only a single number, and that number happens to be the secret key, we get the flag.
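The query format described above can be sketched with two small helpers. The names build_query and parse_reply are hypothetical (they are not part of the challenge files), and the sketch assumes that a reply line may start with the "> " prompt printed by input().

```python
def build_query(values: list[int]) -> str:
    """Format plaintexts the way the server's parser expects: comma-separated integers."""
    if len(values) > 0x100:
        # the server answers "Long input!" and exits beyond this limit
        raise ValueError("at most 0x100 values per line")
    return ",".join(str(v) for v in values)


def parse_reply(line: str) -> list[int]:
    """Parse the ', '-separated ciphertexts, dropping a leading '>' prompt if present."""
    line = line.strip()
    if line.startswith(">"):
        line = line[1:]
    return [int(tok) for tok in line.split(",")]  # int() tolerates surrounding spaces


print(build_query([1, 2, 4]))          # 1,2,4
print(parse_reply("> 10, 20, 30\n"))   # [10, 20, 30]
```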
Sounds easy? Let's take a look at the cipher.
The Cipher
The construction above is a Substitution-Permutation Network (SPN), which is essentially a repeated operation of substitution through a fixed, predefined array, which here is
S = [
    43,  8, 57, 53, 48, 39, 15, 61,
     7, 44, 33,  9, 19, 41,  3, 14,
    42, 51,  6,  2, 49, 28, 55, 31,
     0,  4, 30,  1, 59, 50, 35, 47,
    25, 16, 37, 27, 10, 54, 26, 58,
    62, 13, 18, 22, 21, 24, 12, 20,
    29, 38, 23, 32, 60, 34,  5, 11,
    45, 63, 40, 46, 52, 36, 17, 56
]
i.e. a number n is replaced with S[n], followed by a permutation of bits, here defined by
P = [
    21,  8, 23,  6,  7, 15,
    22, 13, 19, 16, 25, 28,
    31, 32, 34, 36,  3, 39,
    29, 26, 24,  1, 43, 35,
    45, 12, 47, 17, 14, 11,
    27, 37, 41, 38, 40, 20,
     2,  0,  5,  4, 42, 18,
    44, 30, 46, 33,  9, 10
]
which means that bit P[t] of the output is bit t of the input, counting from the most significant bit. E.g. P[0] = 21 means that bit 21 of the output is bit 0 of the input, unchanged,
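To make the P[0] = 21 example concrete, here is a minimal sketch that applies P to a 48-bit value whose only set bit is bit 0 (the most significant one). The function name permute_bits is hypothetical; it mirrors the string-based logic of the challenge's _perm on a whole 48-bit integer.

```python
P = [
    21,  8, 23,  6,  7, 15,
    22, 13, 19, 16, 25, 28,
    31, 32, 34, 36,  3, 39,
    29, 26, 24,  1, 43, 35,
    45, 12, 47, 17, 14, 11,
    27, 37, 41, 38, 40, 20,
     2,  0,  5,  4, 42, 18,
    44, 30, 46, 33,  9, 10,
]


def permute_bits(x: int) -> int:
    bits = f"{x:048b}"            # index 0 is the most significant bit
    buf = ["0"] * 48
    for i, p in enumerate(P):
        buf[p] = bits[i]          # input bit i lands at output position P[i]
    return int("".join(buf), 2)


out = permute_bits(1 << 47)       # only input bit 0 (the MSB) is set
print(f"{out:048b}".index("1"))   # 21
```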
followed by an XOR with a secret key.
This process is repeated for a fixed number of rounds, with a new secret key each round (these are called the round keys or subkeys).
The encrypt function hence looks like this:
def encrypt(self, inp: int) -> int:
    block = self._divide(inp)

    block = self._xor(block, self.subkeys[0])
    for r in range(self.ROUND):
        block = self._sub(block)
        block = self._perm(block)
        block = self._xor(block, self.subkeys[r + 1])

    return self._combine(block)
(_divide and _combine are just helper functions to make the programmer's life easier.)
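As a quick illustration of what these two helpers do, the sketch below re-implements them as free functions (the names divide and combine are only for this example) and round-trips a 48-bit value whose chunks are easy to read off.

```python
def divide(v: int, block_num: int = 8) -> list[int]:
    """Split a 48-bit integer into eight 6-bit chunks, most significant chunk first."""
    chunks = []
    for _ in range(block_num):
        chunks.append(v & 0b111111)
        v >>= 6
    return chunks[::-1]


def combine(chunks: list[int]) -> int:
    """Inverse of divide: glue the 6-bit chunks back into one integer."""
    res = 0
    for c in chunks:
        res = (res << 6) | c
    return res


x = 0b000001_000010_000011_000100_000101_000110_000111_001000
print(divide(x))                 # [1, 2, 3, 4, 5, 6, 7, 8]
print(combine(divide(x)) == x)   # True
```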
One might ask: why are we _divide-ing a perfectly good 48-bit input into 8 chunks of 6 bits each?
Well, in an ideal world we would like to have a substitution box on all 48 bits, but that would eat up a whopping 2^48 entries (which we are somehow faking with 2^6 entries here).
Hence the function _sub sees 8 separate values, substitutes each of them, and acts as if it had just done 48 bits of substitution.
@classmethod
def _sub(cls, block: list[int]) -> list[int]:
    return [cls.S[v] for v in block]
Similarly, _perm behaves (by design) as a permutation on one big block of 48 bits: it takes 8 chunks of 6 bits each, permutes their bits as if they were all joined together, and produces 8 chunks of 6 bits again.
@classmethod
def _perm(cls, block: list[int]) -> list[int]:
    bits = ""
    for b in block:
        bits += f"{b:06b}"

    buf = ["_" for _ in range(6 * cls.BLOCK_NUM)]
    for i in range(6 * cls.BLOCK_NUM):
        buf[cls.P[i]] = bits[i]

    permd = "".join(buf)
    return [int(permd[i : i + 6], 2) for i in range(0, 6 * cls.BLOCK_NUM, 6)]
Where do subkeys come from?
def __init__(self, key: int):
    assert 0 <= key <= self.MASK

    keys = [key]
    for _ in range(self.ROUND):
        v = hashlib.sha256(str(keys[-1]).encode()).digest()
        v = int.from_bytes(v, "big") & self.MASK
        keys.append(v)

    self.subkeys = [self._divide(k) for k in keys]
As you may have observed from the __init__ function, the subkeys are “derived” from a single 48-bit key by repeated hashing, in a way that we can't recover subkey i from the knowledge of any of the subkeys j > i. (This makes the challenge harder: we definitely need to recover subkeys[0], which is the original key.)
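The one-way nature of this chain can be sketched in a few lines. The function name derive_subkeys is hypothetical; it reproduces the hashing loop above on plain integers. Going forward is cheap, so anyone who knows subkey 0 can recompute the other three, while going backward would mean inverting SHA-256.

```python
import hashlib

MASK = (1 << 48) - 1


def derive_subkeys(key: int, rounds: int = 3) -> list[int]:
    """Return [key, H(key), H(H(key)), ...], each truncated to the low 48 bits."""
    keys = [key]
    for _ in range(rounds):
        digest = hashlib.sha256(str(keys[-1]).encode()).digest()
        keys.append(int.from_bytes(digest, "big") & MASK)
    return keys


ks = derive_subkeys(123456789)
# Each later subkey depends only on the one before it, so the tail of the
# chain can be recomputed from subkey 1 alone, but not the other way round.
print(derive_subkeys(ks[1], rounds=2) == ks[1:])   # True
```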
Vulnerability?
If you have seen some cipher constructions before, you may have noticed that ROUND = 3 is really very low, and that 6-bit S-boxes are not as robust as you might imagine.
Another hint provided by the author is differential cryptanalysis, but since I am obsessed with SAT solvers, I will overlook the hint and cheese it with z3.
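For the curious, the hinted differential approach would start from how S treats input differences. As a side illustration only (this writeup does not pursue it), the sketch below computes the difference distribution table of S and prints the most likely non-trivial differential; the helper name ddt is hypothetical.

```python
S = [
    43,  8, 57, 53, 48, 39, 15, 61,
     7, 44, 33,  9, 19, 41,  3, 14,
    42, 51,  6,  2, 49, 28, 55, 31,
     0,  4, 30,  1, 59, 50, 35, 47,
    25, 16, 37, 27, 10, 54, 26, 58,
    62, 13, 18, 22, 21, 24, 12, 20,
    29, 38, 23, 32, 60, 34,  5, 11,
    45, 63, 40, 46, 52, 36, 17, 56,
]


def ddt(sbox: list[int]) -> list[list[int]]:
    """table[dx][dy] = number of inputs x with sbox[x] ^ sbox[x ^ dx] == dy."""
    n = len(sbox)
    table = [[0] * n for _ in range(n)]
    for dx in range(n):
        for x in range(n):
            table[dx][sbox[x] ^ sbox[x ^ dx]] += 1
    return table


table = ddt(S)
# Skip dx = 0, which trivially maps to dy = 0 for all 64 inputs.
count, dx, dy = max((table[dx][dy], dx, dy) for dx in range(1, 64) for dy in range(64))
print(f"best differential: {dx:06b} -> {dy:06b} holds for {count}/64 inputs")
```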
Modelling
The general methodology for solving a problem with a SAT solver is to write the output as a (symbolic) function of the inputs, and then to find an input which leads to the observed output.
So what are the symbolic input and output here?
For an input inp to SusCipher(key) producing an encryption out, we can write out as symbolic_function(subkeys, inp).
With subkeys as the unknowns we are after, and inp and out known, the solver can give us the desired outcome.
Taking heavy inspiration from the implementation of the challenge cipher, we can similarly create the z3 model of SusCipher
class crack:
    ROUND = 3
    BLOCK_NUM = 8

    def _divide(self, v):
        l = []
        for _ in range(self.BLOCK_NUM):
            l.append(v & 0b111111)
            v >>= 6
        return l[::-1]

    def _combine(self, block):
        res = 0
        for v in block:
            res <<= 6
            res |= v
        return res

    def _xor(self, a, b):
        return [x ^ y for x, y in zip(a, b)]
These helper functions are identical to the ones in the original cipher.
Modelling substitution
The first hurdle most people face when modelling an SPN or any other cipher is modelling the substitution.
But z3 is equipped with powerful theories of arrays (and functions).
Thus, to model substitution, we can define a symbolic function S, which takes 6-bit inputs and produces 6-bit outputs
self.S = Function('S', BitVecSort(6), BitVecSort(6))
Then self.S(i) would be exactly what we want.
But wait, we just declared that S can be any function, not the exact substitution function we are given.
Worry not, we can specify this as a constraint to the solver
for i, v in enumerate(S):  # original S as provided in the challenge
    self.solver.add(self.S(i) == v)
i.e. we want S(0) to be nothing other than 43, and so on.
We treat the keys as 6-bit unknowns, so there will be (ROUND+1)*8 = 32 variables.
Overall, our __init__ function will look like
def __init__(self):
    self.S = Function('S', BitVecSort(6), BitVecSort(6))
    self.solver = Solver()
    for i, v in enumerate(S):
        self.solver.add(self.S(i) == v)
    self.keys = [[BitVec(f'k_{r}_{i}', 6) for i in range(8)] for r in range(self.ROUND + 1)]
Hence the _sub function would be
def _sub(self, block):
    return [self.S(simplify(i)) for i in block]
Note that it could have been self.S(i) instead of self.S(simplify(i)); I used simplify just to reduce the expression (if possible) before substituting, to hopefully speed things up.
Modelling permutation
Now what about the permutation? We can model it exactly how we would have calculated a permutation by hand:
take the i-th bit and put it in the P[i]-th place of the output. Only the way of dealing with bit-vectors differs.
def _perm(self, block):
    x = Concat(block)
    # treat the 8 6-bit vectors as a single 48 bit-vector
    output = [0] * 48  # temporary placeholder for output
    for i, v in enumerate(P):
        # extract the ith bit from the MSB put it at the correct place
        output[v] = Extract(47 - i, 47 - i, x)
    # rechunk in 6 bit bitvectors
    return [Concat(output[i:i + 6]) for i in range(0, 48, 6)]
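The index arithmetic here is easy to get wrong: z3's Extract numbers bits from the least significant end, while the challenge code indexes a bit string from the most significant end, hence the 47 - i. A minimal pure-Python sketch (no z3 required; perm_string and perm_bitwise are hypothetical names) checks that the two conventions describe the same permutation:

```python
import random

P = [
    21,  8, 23,  6,  7, 15,
    22, 13, 19, 16, 25, 28,
    31, 32, 34, 36,  3, 39,
    29, 26, 24,  1, 43, 35,
    45, 12, 47, 17, 14, 11,
    27, 37, 41, 38, 40, 20,
     2,  0,  5,  4, 42, 18,
    44, 30, 46, 33,  9, 10,
]


def perm_string(x: int) -> int:
    """String-based version, as in the challenge: index 0 is the MSB."""
    bits = f"{x:048b}"
    buf = ["0"] * 48
    for i, p in enumerate(P):
        buf[p] = bits[i]
    return int("".join(buf), 2)


def perm_bitwise(x: int) -> int:
    """Shift-based version: (x >> (47 - i)) & 1 is the bit Extract(47 - i, 47 - i, x) selects."""
    out = 0
    for i, p in enumerate(P):
        bit = (x >> (47 - i)) & 1
        out |= bit << (47 - p)    # position p counted from the MSB
    return out


samples = [random.getrandbits(48) for _ in range(1000)]
print(all(perm_string(x) == perm_bitwise(x) for x in samples))   # True
```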
Modelling encryption
Finally, with all the building blocks for our symbolic encryption in place, we can model it
def enc(self, block):
    block = self._xor(block, self.keys[0])
    for r in range(self.ROUND):
        block = self._sub(block)
        block = self._perm(block)
        block = self._xor(block, self.keys[r + 1])
    return block
As you can see, this is almost the same as the original, except that we are not _divide-ing and _combine-ing the 48 bits but rather assume that it operates on eight 6-bit values. And self.keys here are the symbolic unknowns.
Checking if our model is correct
Now a CTF player will be anxious about whether the effort put into modelling the cipher was fruitful, or whether the model is wrong somewhere.
Worry not, we can always check our symbolic model by plugging in real values and comparing with the original cipher.
We will use random inputs and keys just to check that they match (kinda funny that we have to informally verify a formal verifier XD)
import random

print("verifying our modelling")
for i in range(100):
    random_key = random.randint(0, 2**48 - 1)
    sus = SusCipher(random_key)
    sus_model = crack()
    sus_model.solver.check()  # to fill in the `S` as the original substitution function
    # replace the unknown keys with the real subkeys; BitVecVal is a symbolic constant value
    sus_model.keys = [[BitVecVal(i, 6) for i in row] for row in sus.subkeys]
    for j in range(10):
        inp = random.randint(0, 2**48 - 1)
        real_out = sus.encrypt(inp)
        sym_out_chunks = sus_model.enc(sus_model._divide(inp))
        # evaluating the symbolic output as per the symbolic model
        sym_out = sus_model.solver.model().eval(Concat(sym_out_chunks))
        assert sym_out.as_long() == real_out
print("success")
Adding input output points
Taking care of the _divide business, we equate the 6-bit chunks of the observed output with those of our symbolic output for a given input
def add_sample(self, inp, oup):
    for a, b in zip(self.enc(self._divide(inp)), self._divide(oup)):
        self.solver.add(a == b)
Getting the key
It's really simple: just check whether there is a model satisfying all of our constraints, and read the first subkey off that model
def get(self):
    if self.solver.check() == sat:
        model = self.solver.model()
        k = [model.eval(i).as_long() for i in self.keys[0]]
        return self._combine(k)
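Since get() returns the key from whichever satisfying model the solver finds first, and the model treats the four subkeys as independent unknowns, it can be reassuring to re-check a candidate locally before sending it. A minimal sketch, assuming an integer re-implementation of the cipher; subkeys_from, sub, perm, encrypt and key_matches are hypothetical names, not part of the challenge files.

```python
import hashlib

S = [
    43,  8, 57, 53, 48, 39, 15, 61,  7, 44, 33,  9, 19, 41,  3, 14,
    42, 51,  6,  2, 49, 28, 55, 31,  0,  4, 30,  1, 59, 50, 35, 47,
    25, 16, 37, 27, 10, 54, 26, 58, 62, 13, 18, 22, 21, 24, 12, 20,
    29, 38, 23, 32, 60, 34,  5, 11, 45, 63, 40, 46, 52, 36, 17, 56,
]
P = [
    21,  8, 23,  6,  7, 15, 22, 13, 19, 16, 25, 28,
    31, 32, 34, 36,  3, 39, 29, 26, 24,  1, 43, 35,
    45, 12, 47, 17, 14, 11, 27, 37, 41, 38, 40, 20,
     2,  0,  5,  4, 42, 18, 44, 30, 46, 33,  9, 10,
]
MASK = (1 << 48) - 1


def subkeys_from(key: int, rounds: int = 3) -> list[int]:
    keys = [key]
    for _ in range(rounds):
        digest = hashlib.sha256(str(keys[-1]).encode()).digest()
        keys.append(int.from_bytes(digest, "big") & MASK)
    return keys


def sub(x: int) -> int:
    # apply S to each of the eight 6-bit chunks of the 48-bit state
    out = 0
    for shift in range(0, 48, 6):
        out |= S[(x >> shift) & 0b111111] << shift
    return out


def perm(x: int) -> int:
    # input bit i (counted from the MSB) moves to output position P[i]
    out = 0
    for i, p in enumerate(P):
        out |= ((x >> (47 - i)) & 1) << (47 - p)
    return out


def encrypt(key: int, inp: int) -> int:
    ks = subkeys_from(key)
    state = inp ^ ks[0]
    for r in range(3):
        state = perm(sub(state)) ^ ks[r + 1]
    return state


def key_matches(key: int, samples: list[tuple[int, int]]) -> bool:
    """True if the candidate key reproduces every observed (input, output) pair."""
    return all(encrypt(key, inp) == out for inp, out in samples)
```

With the (input, output) pairs collected from the server, key_matches(candidate, pairs) returning False would mean the solver found subkeys that fit the samples but do not come from a single master key.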
Putting our class together
from z3 import *

S = [
    43,  8, 57, 53, 48, 39, 15, 61,
     7, 44, 33,  9, 19, 41,  3, 14,
    42, 51,  6,  2, 49, 28, 55, 31,
     0,  4, 30,  1, 59, 50, 35, 47,
    25, 16, 37, 27, 10, 54, 26, 58,
    62, 13, 18, 22, 21, 24, 12, 20,
    29, 38, 23, 32, 60, 34,  5, 11,
    45, 63, 40, 46, 52, 36, 17, 56
]

P = [
    21,  8, 23,  6,  7, 15,
    22, 13, 19, 16, 25, 28,
    31, 32, 34, 36,  3, 39,
    29, 26, 24,  1, 43, 35,
    45, 12, 47, 17, 14, 11,
    27, 37, 41, 38, 40, 20,
     2,  0,  5,  4, 42, 18,
    44, 30, 46, 33,  9, 10
]


class crack:
    ROUND = 3
    BLOCK_NUM = 8

    def __init__(self):
        self.S = Function('S', BitVecSort(6), BitVecSort(6))
        self.solver = Solver()
        for i, v in enumerate(S):
            self.solver.add(self.S(i) == v)
        self.keys = [[BitVec(f'k_{r}_{i}', 6) for i in range(8)] for r in range(self.ROUND + 1)]

    def _divide(self, v):
        l = []
        for _ in range(self.BLOCK_NUM):
            l.append(v & 0b111111)
            v >>= 6
        return l[::-1]

    def _combine(self, block):
        res = 0
        for v in block:
            res <<= 6
            res |= v
        return res

    def _xor(self, a, b):
        return [x ^ y for x, y in zip(a, b)]

    def _perm(self, block):
        x = Concat(block)
        output = [0] * 48
        for i, v in enumerate(P):
            output[v] = Extract(47 - i, 47 - i, x)
        return [Concat(output[i:i + 6]) for i in range(0, 48, 6)]

    def _sub(self, block):
        return [self.S(simplify(i)) for i in block]

    def enc(self, block):
        block = self._xor(block, self.keys[0])
        for r in range(self.ROUND):
            block = self._sub(block)
            block = self._perm(block)
            block = self._xor(block, self.keys[r + 1])
        return block

    def add_sample(self, inp, oup):
        for a, b in zip(self.enc(self._divide(inp)), self._divide(oup)):
            self.solver.add(a == b)

    def get(self):
        if self.solver.check() == sat:
            model = self.solver.model()
            k = [model.eval(i).as_long() for i in self.keys[0]]
            return self._combine(k)
So how many input-output pairs do we need to determine the key uniquely?
I guess at most 256?
Let's try it out
c = crack()
for i in range(256):
    inp = random.randint(0, 2**48 - 1)
    out = server(inp)  # whatever
    c.add_sample(inp, out)
key = c.get()
Hmmm, something's not right: it seems to be stuck indefinitely.
We can get a feel for how hard it is for the solver to find the key by reducing the number of constraints, i.e. the number of input-output pairs.
By playing around, one quickly realises that it won't work even for 5 random samples and will time out (>200s).
Moment of inspiration
How about we address the difficulty for the solver (by addressing the difficulty of the problem it is asked to solve)?
When we take a random input-output pair, what we ask the solver about is Substitution(key[i] ^ some_random) in every chunk.
But if it were just 0 instead of some_random, the solver would have one less step to guess.
So how about we make 7 out of the 8 chunks 0 and keep only one key position active in the substitution?
This is really easy with input = (1<<i) for 0 <= i < 48.
And most importantly, it works!
(To my amazement, it works in around a second with 48 samples, as opposed to ~5000 seconds for 5 random samples!)
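The claim that 1<<i keeps exactly one chunk active can be checked directly. A small sketch, with divide as a free-function copy of the challenge's _divide (the name is only for this example):

```python
def divide(v: int, block_num: int = 8) -> list[int]:
    """Split a 48-bit integer into eight 6-bit chunks, most significant chunk first."""
    chunks = []
    for _ in range(block_num):
        chunks.append(v & 0b111111)
        v >>= 6
    return chunks[::-1]


for i in (0, 6, 47):
    print(i, divide(1 << i))
# 0 [0, 0, 0, 0, 0, 0, 0, 1]
# 6 [0, 0, 0, 0, 0, 0, 1, 0]
# 47 [32, 0, 0, 0, 0, 0, 0, 0]

# Every one of the 48 single-bit inputs leaves seven chunks at zero,
# so seven of the eight first-round S-box inputs are just the raw key chunk.
active = [sum(1 for c in divide(1 << i) if c) for i in range(48)]
print(set(active))   # {1}
```

All 48 inputs also fit comfortably in a single query line, well under the 0x100 limit.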
Getting the flag
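import pwn

HOST, PORT = "suscipher.chal.ctf.acsc.asia", 13579
REM = pwn.remote(HOST, PORT)
# ask for the encryptions of all 48 single-bit inputs in one line
REM.sendline(",".join(str(1 << i) for i in range(48)).encode())
# drop the leading "> " prompt and parse the ", "-separated ciphertexts
response = list(map(int, REM.recvline()[2:].strip().split(b', ')))
c = crack()
for i, v in enumerate(response):
    c.add_sample(1 << i, v)
key = c.get()
# a single number equal to the key makes the server print the flag
REM.sendline(str(key).encode())
REM.interactive()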
ACSC{There_may_be_a_better_solution_to_solve_this_but_I_used_diff_analysis_:(}
As expected, the author knew there might be other interesting ways like this one ;)