Tasty Crypto Roll
Description
CRYPTO - Hard
Bob, the genius intern at our company, invented AES-improved. It is based on AES but with layers after layers of proprietary encryption techniques on top of it. The end result is an encryption scheme that achieves both confusion and diffusion. The more layers of crypto you add, the better the security, right?
Encrypter: encrypt.py
Encrypted file: enc.bin
Note
The intended solution requires very little brute force and runs under 5 seconds on our machine.
By k3v1n
Source
import os
import random
import secrets
import sys
from Crypto.Cipher import AES
ENCODING = 'utf-8'
def generate_key():
    return os.getpid(), secrets.token_bytes(16)
def to_binary(b: bytes):
    return ''.join(['{:08b}'.format(c) for c in b])
def from_binary(s: str):
    return bytes(int(s[i:i+8], 2) for i in range(0, len(s), 8))
def encrypt(key: bytes, message: bytes):
    cipher = AES.new(key, AES.MODE_ECB)
    return cipher.encrypt(message)
key1, key2 = generate_key()
print(f'Using Key:\n{key1}:{key2.hex()}')
def get_flag():
    flag = input('Enter the flag to encrypt: ')
    if not flag.startswith('sdctf{') or not flag.endswith('}') or not flag.isascii():
        print(f'{flag} is not a valid flag for this challenge')
        sys.exit(1)
    return flag
plaintext = get_flag()[6:-1]
data = plaintext.encode(ENCODING)
codes = list(''.join(chr(i) * 2 for i in range(0xb0, 0x1b0)))
random.seed(key1)
random.shuffle(codes)
sboxes = [''.join(codes[i*4:(i+1)*4]) for i in range(128)]
if len(set(sboxes)) < 128:
    print("Bad key, try again")
    sys.exit(1)
data = ''.join(sboxes[c] for c in data).encode(ENCODING)
data = encrypt(key2, to_binary(data).encode(ENCODING))
random.seed(key1)
key_final = bytes(random.randrange(256) for _ in range(16))
data_bits = list(to_binary(data))
random.shuffle(data_bits)
data = from_binary(''.join(data_bits))
ciphertext = encrypt(key_final, data)
print(f'Encrypted: {ciphertext.hex()}')
with open('enc.bin2', 'wb') as ef:
    ef.write(ciphertext)
Analysis
Here we can see two main points:
- There are two keys:
  - key1: PID of the current process
  - key2: secure random key of 16 bytes
- key1 is used as a seed in several places and is brute-forceable (< 2^15): key_final and sboxes are derived from key1, and the shuffling is also done using key1.
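Since every use of randomness except key2 goes through random.seed(key1), a PID candidate fully determines key_final. A minimal sketch of that observation; derive_key_final is our own hypothetical helper, and the 2^15 bound assumes the common Linux default pid_max of 32768:

```python
import random

# Assumption: PIDs stay below 2**15, as with the common Linux default pid_max of 32768.
PID_MAX = 2**15

def derive_key_final(pid: int) -> bytes:
    """Hypothetical helper: rebuild key_final the way encrypt.py does after re-seeding with key1."""
    random.seed(pid)
    return bytes(random.randrange(256) for _ in range(16))

# The same PID always yields the same key, so the whole key space is just range(PID_MAX).
print(len(range(PID_MAX)), derive_key_final(83).hex())
```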
Steps to crack
1. Decrypt using key_final.
2. Convert the intermediate ciphertext to_binary.
3. De-shuffle the bits.
4. Generate the intermediate ciphertext from_binary of the de-shuffled bits.
5. Decrypt using key2.
6. ???
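Step 3 works because the swaps random.shuffle performs depend only on the RNG state and the list length, not on the list contents. Shuffling the index list [0, 1, ..., n-1] with the same seed therefore reveals the permutation. A small sketch of the trick (helper names are ours); in the real challenge the 16 randrange draws for key_final must be replayed after seeding and before the shuffle, otherwise the recovered permutation will not match:

```python
import random

def shuffle_permutation(seed: int, n: int) -> list[int]:
    """Return the permutation random.shuffle applies to any length-n list after random.seed(seed)."""
    random.seed(seed)
    order = list(range(n))
    random.shuffle(order)
    return order

def unshuffle(shuffled: list, order: list[int]) -> list:
    """Invert the shuffle: element i of the shuffled list originally sat at position order[i]."""
    original = [None] * len(shuffled)
    for i, src in enumerate(order):
        original[src] = shuffled[i]
    return original

bits = list("0110100111001010")
random.seed(83)
scrambled = bits.copy()
random.shuffle(scrambled)
print(unshuffle(scrambled, shuffle_permutation(83, len(bits))) == bits)  # True
```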
How to find key1?
Assume we have the correct key1, reverse the encryption steps that depend on it, and validate the result using some validator or logical assumption.
codes is a list of 2*(0x1b0-0xb0) = 512 characters, each of which is 2 bytes long in UTF-8.
sboxes holds 4-character strings, which encode to 8 bytes each in UTF-8, so after the substitution data is 4*2 = 8 times as long as the original plaintext.
data is converted to_binary before encryption, hence each byte becomes 8 bytes of b"0" or b"1". Hence each flag character is substituted by some 8*8 = 64-byte string before encryption.
Hence the length of the flag is len(ciphertext)//64 = 3520//64 = 55 bytes.
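The size arithmetic can be checked directly with the constants from encrypt.py (3520 being the ciphertext length quoted above):

```python
codes = list(''.join(chr(i) * 2 for i in range(0xb0, 0x1b0)))
assert len(codes) == 512
# every code point in [0xb0, 0x1b0) takes exactly 2 bytes in UTF-8
assert all(len(c.encode('utf-8')) == 2 for c in codes)

sbox = ''.join(codes[:4])            # one 4-character sbox entry
sbox_bytes = sbox.encode('utf-8')    # 8 bytes after UTF-8 encoding
bits = ''.join('{:08b}'.format(b) for b in sbox_bytes)  # to_binary: 64 '0'/'1' characters
print(len(sbox_bytes), len(bits))    # 8 64
print(3520 // len(bits))             # 55
```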
Assumption 1
Since the flag is 55 characters long, it seems reasonable to assume that some characters repeat. And since each flag character is substituted by a fixed 64-byte string before encryption, which is a multiple of the AES block size of 16, AES (ECB with a single key) also acts as a simple substitution on the flag, but we do not know the mapping.
Hence if we reverse up to step 4 above, we can simply check whether any 64-byte blocks repeat, as an incorrect un-shuffling of the bits will make every block distinct with probability almost 1.
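The validator boils down to "does any aligned block occur twice?". A small helper expressing it (the name has_repeated_blocks is ours):

```python
from collections import Counter

def has_repeated_blocks(data: bytes, block_size: int) -> bool:
    """True if any aligned block of block_size bytes occurs more than once in data."""
    blocks = Counter(data[i:i + block_size] for i in range(0, len(data), block_size))
    n_blocks = (len(data) + block_size - 1) // block_size
    return len(blocks) < n_blocks

print(has_repeated_blocks(b'A' * 64 + b'B' * 64 + b'A' * 64, 64))  # True
print(has_repeated_blocks(bytes(range(128)), 64))                   # False
```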
import random
from Crypto.Cipher import AES
with open('enc.bin', 'rb') as f:
    ciphertext = f.read()
def to_binary(b: bytes):
    return ''.join(['{:08b}'.format(c) for c in b])
def from_binary(s: str):
    return bytes(int(s[i:i+8], 2) for i in range(0, len(s), 8))
def encrypt(key, message):
    return AES.new(key, AES.MODE_ECB).encrypt(message)
def decrypt(key: bytes, message: bytes):
    return AES.new(key, AES.MODE_ECB).decrypt(message)
def unshuffle(data_list, shuffle_order):
    res = [None]*len(data_list)
    for i, v in enumerate(shuffle_order):
        res[v] = data_list[i]
    return res
def key_final_dec(key1, ciphertext):
    random.seed(key1)
    key_final = bytes(random.randrange(256) for _ in range(16))
    data = decrypt(key_final, ciphertext)
    data_bits = list(to_binary(data))
    # shuffling the index list replays the same permutation as encrypt.py
    data_bits_order = list(range(len(data_bits)))
    random.shuffle(data_bits_order)
    data_bits_uns = unshuffle(data_bits, data_bits_order)
    data = from_binary(''.join(data_bits_uns))
    return data
Let's add a few validation tests too:
def key_final_enc(key1, data):
    random.seed(key1)
    key_final = bytes(random.randrange(256) for _ in range(16))
    data_bits = list(to_binary(data))
    random.shuffle(data_bits)
    data = from_binary(''.join(data_bits))
    return encrypt(key_final, data)
def test_unshuffle():
    random_text = list(random.randbytes(16*1337))
    random_text_shuffled = random_text.copy()
    shuffle_order = list(range(len(random_text)))
    random.seed(1337)
    random.shuffle(random_text_shuffled)
    random.seed(1337)
    random.shuffle(shuffle_order)
    assert unshuffle(random_text_shuffled, shuffle_order) == random_text
def test_key_final_dec():
    random_text = random.randbytes(16*100)
    assert key_final_dec(1337, key_final_enc(1337, random_text)) == random_text
test_unshuffle()
test_key_final_dec()
Looks like all the decryption functions are correct, so let's proceed with brute-forcing key1:
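from collections import Counter
from tqdm import tqdm
for key1 in tqdm(range(2**15), desc='solving for key1'):
    data = key_final_dec(key1, ciphertext)
    # one 64-byte block per flag character; a repeat means a repeated character
    substitutions = Counter(data[i:i+64] for i in range(0, len(data), 64))
    if len(substitutions) != len(data)//64:
        print("pid =", key1)
        break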
After waiting for an eternity and exhausting the search space of possible PIDs without finding any key1, I was confused. I checked my script locally with a test flag and it seemed to work fine. There could be only one explanation: the flag contains 55 distinct characters.
But how would I find key1 now?
Missed Catch
@Utaha#6878 pointed out that codes contains only 256 distinct values, each repeated twice, and that each code character is encoded to a 16-byte string of b"0"/b"1" bytes, i.e. exactly one AES block, so the same code must always be encrypted to the same block.
Since the flag yields 55*4 = 220 such 16-byte codes and every code value appears twice across the sboxes, some code values are almost certainly used more than once, so there will be repeating 16-byte blocks even with distinct flag characters.
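The claim is easy to check empirically. This sketch (our own helper, reusing the codes construction from encrypt.py) picks 55 distinct sboxes, as a flag with 55 distinct characters would, and counts code values occurring more than once among the resulting 220 codes. Picking the sboxes with random.sample is an assumption standing in for "some set of distinct characters":

```python
import random

def repeated_codes(seed: int, n_chars: int = 55) -> int:
    """Count code values used twice among the 4*n_chars codes of n_chars distinct sboxes."""
    codes = list(''.join(chr(i) * 2 for i in range(0xb0, 0x1b0)))
    random.seed(seed)
    random.shuffle(codes)
    sboxes = [codes[i * 4:(i + 1) * 4] for i in range(128)]
    chosen = random.sample(range(128), n_chars)  # assumption: arbitrary distinct flag characters
    used = [c for idx in chosen for c in sboxes[idx]]
    return len(used) - len(set(used))

print([repeated_codes(seed) for seed in range(10)])
```

Every code value has two copies spread over 128 sboxes, so with 55 sboxes selected the count is typically well above zero.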
Assumption 2
for key1 in tqdm(range(2**15), desc='solving for key1'):
    data = key_final_dec(key1, ciphertext)
    substitutions = Counter(data[i:i+16] for i in range(0, len(data), 16))
    if len(substitutions) != len(data)//16:
        print("pid =", key1)
        break
pid = 83
And we found our key1!
We can also confirm that the flag indeed consists of 55 distinct characters.
Wait, if the flag consists of 55 distinct characters, how will we solve for the substitution?
We have no statistical advantage, so bye bye Mr quipqiup.
How do we find the mapping for the substitution?
Each sbox entry is composed of four 2-byte strings, each of which can be one of 256 possible values. Moreover, their order is fixed and determined by key1.
If we solve for all valid mappings of AES(binary(sbox(char))), we will probably end up with the correct mapping and get our flag.
+---------------+---------------+------------------------+---------------+
| flag0         | flag1         |                        | flag54        |
+---------------+---------------+          ....          +---------------+
| sbox          | sbox          |                        | sbox          |
+---+---+---+---+---+---+---+---+------------------------+---------------+
|c1 |c2 |c3 |c4 |c5 |c6 |c7 |c8 |                        |               |
|   |   |   |   |   |   |   |   |                        |               |
+---+---+---+---+---+---+---+---+          ....          +---------------+
|      AES      |      AES      |                        |               |
+---+---+-------+---------------+------------------------+---------------+
|   |   +------+
|   +--+       |
+------+-------+-------+------+
|E(c1) | E(c2) | E(c3) | E(c4)|
+------+-------+-------+------+
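Since only the equality of blocks matters, the 16-byte blocks can be relabelled as small integers and grouped four per flag character. This roughly matches the shape of the cip input that the alternate solution's brute() expects; its actual construction lives in solve2.py, which isn't shown, so blocks_to_labels below is our hypothetical guess:

```python
def blocks_to_labels(data: bytes, block_size: int = 16, per_char: int = 4) -> list[list[int]]:
    """Hypothetical helper: relabel aligned blocks as integers and group them per flag character."""
    labels: dict[bytes, int] = {}  # block -> label, numbered in first-seen order
    seq = [labels.setdefault(data[i:i + block_size], len(labels))
           for i in range(0, len(data), block_size)]
    return [seq[i:i + per_char] for i in range(0, len(seq), per_char)]

# Equal labels mean equal AES outputs, i.e. (ECB with one key) equal 2-byte codes.
example = b''.join(bytes([c]) * 16 for c in b'ABACBDEF')
print(blocks_to_labels(example))  # [[0, 1, 0, 2], [1, 3, 4, 5]]
```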
Enter Z3
We can model our flag as a list of 7-bit BitVecs.
And let the sboxes be a mapping from 7 bits to 64 bits each (16x4). This can be achieved by modelling sbox as an Array indexed by BitVec(7) whose elements are BitVec(64).
And we model AES as an uninterpreted function from BitVec(16) to BitVec(128).
import random
from collections import Counter
from z3 import *
# key1, ciphertext, key_final_dec() and all_smt() come from the earlier snippets and the full script
flag = [BitVec('flag_'+str(i), 7) for i in range(len(ciphertext)//64)]
sboxmap = Array('sbox', BitVecSort(7), BitVecSort(64))
aes_encryption = Function('AES', BitVecSort(16), BitVecSort(128))
codes = list(''.join(chr(i) * 2 for i in range(0xb0, 0x1b0)))
random.seed(key1)
random.shuffle(codes)
# keeping sboxes utf-8 encoded already
sboxes = [''.join(codes[i*4:(i+1)*4]).encode() for i in range(128)]
sbytes = b''.join(sboxes)
# integer values for the distinct 2-byte codes from the sboxes, explained shortly
sboxints = list(map(lambda x: int.from_bytes(x, 'big'),
                    set(sbytes[i:i+2] for i in range(0, len(sbytes), 2))))
sboxes = [int.from_bytes(i, 'big') for i in sboxes]
data = key_final_dec(key1, ciphertext)
# converting the intermediate decryption to 128-bit ints
data_int = []
for i in range(0, len(data), 16):
    data_int.append(int.from_bytes(data[i:i+16], 'big'))
# we know the sbox already
constraints = [sboxmap[i] == sboxes[i] for i in range(128)]
for i in range(len(data)//64):
    four_code = sboxmap[flag[i]]
    # splitting the 64-bit quantity into 16-bit sbox codes, most significant first
    four_code_parts = [Extract(16*j+15, 16*j, four_code) for j in range(3, -1, -1)]
    # for each code, match aes_encryption with the observed value
    for a, b in zip(data_int[4*i:4*i+4], four_code_parts):
        constraints.append(aes_encryption(b) == a)
# last but not least, aes_encryption(i) is unique for each plaintext.
# How would z3 know? Distinct states that all these outputs must differ
constraints.append(Distinct([aes_encryption(i) for i in sboxints]))
solver = Solver()
solver.add(constraints)
for m in all_smt(solver, flag):
    # check all satisfying flags, in case there is more than one possible
    # mapping (we would then have to rule out the invalid ones)
    flag_bytes = bytes([m.eval(flag[i]).as_long() for i in range(len(flag))])
    assert len(set(flag_bytes)) == len(Counter(data[i:i+64] for i in range(0, len(data), 64)))
    print(flag_bytes)
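The constraints handed to Z3 above can also be phrased as a plain-Python check of one candidate flag, which is handy for sanity-checking a model. A sketch with our own names, working on small integer labels instead of the 16-bit codes and 128-bit blocks:

```python
def is_consistent(flag: bytes, sboxes: list[list[int]], cip: list[list[int]]) -> bool:
    """Check one candidate flag against the same conditions the Z3 model encodes.

    sboxes[c] lists the 4 code labels for character c; cip[k] lists the 4 block
    labels observed at flag position k. The code -> block map must be a function
    (same code, same block) and injective (the Distinct constraint).
    """
    code_to_block: dict[int, int] = {}
    for ch, blocks in zip(flag, cip):
        for code, block in zip(sboxes[ch], blocks):
            if code_to_block.setdefault(code, block) != block:
                return False  # one code would need two different AES outputs
    return len(set(code_to_block.values())) == len(code_to_block)

toy_sboxes = [[0, 1, 2, 3], [4, 0, 5, 6], [7, 8, 9, 10]]
toy_cip = [[10, 11, 12, 13], [14, 10, 15, 16]]
print(is_consistent(bytes([0, 1]), toy_sboxes, toy_cip))  # True
print(is_consistent(bytes([1, 0]), toy_sboxes, toy_cip))  # False
```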
Flag
After running the script, we finally get our flag!
b'r0l1-uR~pWn.c6yPtO_wi7h,ECB:I5*b8d!KQvJmLxgX9DsaANMFSeU'
And it turns out to be the only satisfying assignment.
It turns out that if there were repeated characters in the flag, we would get multiple possible satisfying values. So the admins were not so cheeky after all.
Full script
Note that it takes a couple of seconds to find the z3 model
import random
from Crypto.Cipher import AES
from collections import Counter
from tqdm import tqdm
from z3 import *
import sys
def all_smt(s, initial_terms):
    def block_term(s, m, t):
        s.add(t != m.eval(t))
    def fix_term(s, m, t):
        s.add(t == m.eval(t))
    def all_smt_rec(terms):
        if sat == s.check():
            m = s.model()
            yield m
            for i in range(len(terms)):
                s.push()
                block_term(s, m, terms[i])
                for j in range(i):
                    fix_term(s, m, terms[j])
                yield from all_smt_rec(terms[i:])
                s.pop()
    yield from all_smt_rec(list(initial_terms))
with open('enc.bin', 'rb') as f:
    ciphertext = f.read()
def to_binary(b: bytes):
    return ''.join(['{:08b}'.format(c) for c in b])
def from_binary(s: str):
    return bytes(int(s[i:i + 8], 2) for i in range(0, len(s), 8))
def encrypt(key, message):
    return AES.new(key, AES.MODE_ECB).encrypt(message)
def decrypt(key: bytes, message: bytes):
    return AES.new(key, AES.MODE_ECB).decrypt(message)
def key_final_enc(key1, data):
    random.seed(key1)
    key_final = bytes(random.randrange(256) for _ in range(16))
    data_bits = list(to_binary(data))
    random.shuffle(data_bits)
    data = from_binary(''.join(data_bits))
    return encrypt(key_final, data)
def unshuffle(data_list, shuffle_order):
    res = [None] * len(data_list)
    for i, v in enumerate(shuffle_order):
        res[v] = data_list[i]
    return res
def test_unshuffle():
    random_text = list(random.randbytes(16 * 100))
    random_text_shuffled = random_text.copy()
    shuffle_order = list(range(len(random_text)))
    random.seed(10)
    random.shuffle(random_text_shuffled)
    random.seed(10)
    random.shuffle(shuffle_order)
    assert unshuffle(random_text_shuffled, shuffle_order) == random_text
test_unshuffle()
def key_final_dec(key1, ciphertext):
    random.seed(key1)
    key_final = bytes(random.randrange(256) for _ in range(16))
    data = decrypt(key_final, ciphertext)
    data_bits = list(to_binary(data))
    data_bits_order = list(range(len(data_bits)))
    random.shuffle(data_bits_order)
    data_bits_uns = unshuffle(data_bits, data_bits_order)
    data = from_binary(''.join(data_bits_uns))
    return data
def test_key_final_dec():
    random_text = random.randbytes(16 * 100)
    assert key_final_dec(10, key_final_enc(10, random_text)) == random_text
test_key_final_dec()
for key1 in tqdm(range(2**15), desc='solving for key1'):
    data = key_final_dec(key1, ciphertext)
    substitutions = Counter(data[i:i + 16] for i in range(0, len(data), 16))
    if len(substitutions) != len(data) // 16:
        print("pid =", key1)
        break
else:
    # the loop finished without `break`: no candidate produced repeated blocks
    sys.exit("no candidate key1 found")
codes = list(''.join(chr(i) * 2 for i in range(0xb0, 0x1b0)))
random.seed(key1)
random.shuffle(codes)
sboxes = [''.join(codes[i * 4:(i + 1) * 4]).encode() for i in range(128)]
sbytes = b''.join(sboxes)
sboxints = list(map(lambda x: int.from_bytes(x, 'big'), set(
    sbytes[i:i + 2] for i in range(0, len(sbytes), 2))))
sboxes = [int.from_bytes(i, 'big') for i in sboxes]
data = key_final_dec(key1, ciphertext)
data_int = []
for i in range(0, len(data), 16):
    data_int.append(int.from_bytes(data[i:i + 16], 'big'))
flag = [BitVec('flag_' + str(i), 7) for i in range(len(data) // 64)]
sboxmap = Array('sbox', BitVecSort(7), BitVecSort(64))
aes_encryption = Function('AES', BitVecSort(16), BitVecSort(128))
constraints = [sboxmap[i] == sboxes[i] for i in range(128)]
for i in range(len(data) // 64):
    four_code = sboxmap[flag[i]]
    four_code_parts = [Extract(16 * j + 15, 16 * j, four_code)
                       for j in range(3, -1, -1)]
    for a, b in zip(data_int[4 * i:4 * i + 4], four_code_parts):
        constraints.append(aes_encryption(b) == a)
constraints.append(Distinct([aes_encryption(i) for i in sboxints]))
solver = Solver()
solver.add(constraints)
# a for-loop's `else` runs whenever the loop ends without `break`,
# so track explicitly whether any model was found
found = False
for m in all_smt(solver, flag):
    found = True
    flag_bytes = bytes([m.eval(flag[i]).as_long() for i in range(len(flag))])
    assert len(set(flag_bytes)) == len(
        Counter(data[i:i + 64] for i in range(0, len(data), 64)))
    print(flag_bytes)
if not found:
    print("failed to solve")
Alternate Solution by teammate (Utaha#6878)
All due regards to him for solving the challenge while I was stuck on finding key1 XD
All parts are almost the same except the substitution-solving part, which he did by manual brute force, i.e. recursively enumerating all mappings and backtracking on contradictions.
# cip (the relabelled ciphertext blocks, four per flag character) and key1
# come from the earlier steps; see solve2.py
DEBUG = False  # set to True to print the partially recovered flag at each step
mp = dict()
codes = sum([[i, i] for i in range(256)], start=[])
# notice that the range is changed from [0xb0, 0x1b0) to [0, 256).
# It's just for relabeling.
random.seed(key1)
random.shuffle(codes)
sboxes = [codes[i*4:(i+1)*4] for i in range(128)]
def match(a, b):
    """
    equate two objects elementwise ignoring if the entry is -1
    """
    for x, y in zip(a, b):
        if x == -1 or y == -1:
            continue
        if x != y:
            return False
    return True
answers = []
def getFlag(cip, sboxes, mp):
    # get the flag based on current mapping, unknown char will be shown as '?'
    res = []
    for c in cip:
        afterMap = [mp.get(x, -1) for x in c]
        found = False
        for i, s in enumerate(sboxes):
            if s == afterMap:
                res.append(i)
                found = True
                break
        if not found:
            res.append(ord('?'))
    return bytes(res)
def brute(cip, sboxes, mp):
    """
    cip and sboxes remain unchanged throughout the recursive calls,
    but I feel bad using global variables.
    """
    if DEBUG:
        print(getFlag(cip, sboxes, mp))
    # check if finished
    isFinished = True
    for c in cip:
        if all(x in mp for x in c):
            pass
        else:
            isFinished = False
    if isFinished:
        answers.append(getFlag(cip, sboxes, mp))
        print("Found an answer!!!!!!!")
        return
    # try matching
    isContradiction = False
    mp = mp.copy()
    # Find the one with least possible matches.
    min_pos = 256
    index = -1
    for idx, c in enumerate(cip):
        afterMap = [mp.get(x, -1) for x in c]
        if -1 not in afterMap:
            continue
        matches = [s for s in sboxes if match(s, afterMap)]
        if len(matches) == 0:
            isContradiction = True
            break
        if min_pos > len(matches):
            index = idx
            min_pos = len(matches)
    if isContradiction:
        return
    # now bruteforce all possibilities
    assert index != -1
    afterMap = [mp.get(x, -1) for x in cip[index]]
    matches = [s for s in sboxes if match(s, afterMap)]
    for m in matches:
        for x, y in zip(cip[index], m):
            mp[x] = y
        brute(cip, sboxes, mp)
# Seed the mapping from the repetitions printed below: labels 35 and 109 repeat
# at the same positions as codes 224 and 144, and label 4 is either 132 or 197
for code4 in [132, 197]:
    mp = {35: 224, 109: 144, 4: code4}
    brute(cip, sboxes, mp)
print("Answers:")
answers = list(set(answers))
for x in answers:
    print(b"sdctf{" + x + b"}")
# set() ordering is arbitrary, so the position of the real flag varies between runs;
# in the run shown below it is the third answer, the one matching the Z3 result
Ciphertext repetition: [4, 5, 4, 6] [34, 35, 36, 35] [109, 60, 110, 109]
Sbox repetition: [132, 93, 132, 211] [197, 32, 197, 248] [144, 86, 67, 144] [165, 224, 27, 224]
Found an answer!!!!!!!
Found an answer!!!!!!!
Found an answer!!!!!!!
Found an answer!!!!!!!
Found an answer!!!!!!!
Found an answer!!!!!!!
Found an answer!!!!!!!
Found an answer!!!!!!!
Answers:
b'sdctf{r0l1-LR~pWn.c6yPtO_wi7h,ECB:I5*b8d!KQvJmLxgX95saANMFSeU}'
b'sdctf{r0l1-uR~pWn.c6yPtO_wi7h,ECB:I5*b8d!cQvJmLxgX9DsaANMFSeU}'
b'sdctf{r0l1-uR~pWn.c6yPtO_wi7h,ECB:I5*b8d!KQvJmLxgX9DsaANMFSeU}'
b'sdctf{r0l1-uR~pWn.c6yPtO_wi7h,ECB:I5*b8d!cQvJmLxgX95saANMFSeU}'
b'sdctf{r0l1-LR~pWn.c6yPtO_wi7h,ECB:I5*b8d!KQvJmLxgX9DsaANMFSeU}'
b'sdctf{r0l1-LR~pWn.c6yPtO_wi7h,ECB:I5*b8d!cQvJmLxgX9DsaANMFSeU}'
b'sdctf{r0l1-uR~pWn.c6yPtO_wi7h,ECB:I5*b8d!KQvJmLxgX95saANMFSeU}'
b'sdctf{r0l1-LR~pWn.c6yPtO_wi7h,ECB:I5*b8d!cQvJmLxgX95saANMFSeU}'
The full script is in solve2.py.
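For comparison, here is a stripped-down sketch of the same backtracking idea on a toy instance (function name and data are ours). Unlike the alternate solution above, it also enforces that different codes map to different blocks (the injectivity that the Z3 Distinct constraint expresses), and it simply fills positions left to right instead of picking the most constrained position first, so it is only practical on small inputs:

```python
from collections.abc import Iterator

def solve_substitution(cip: list[list[int]], sboxes: list[list[int]]) -> Iterator[list[int]]:
    """Yield every flag (as sbox indices) whose codes map to the observed blocks injectively."""

    def extend(pos, code_to_block, block_to_code, flag):
        if pos == len(cip):
            yield list(flag)
            return
        for ch, codes in enumerate(sboxes):
            c2b, b2c = dict(code_to_block), dict(block_to_code)
            ok = True
            for code, block in zip(codes, cip[pos]):
                # same code -> same block, and different codes -> different blocks
                if c2b.setdefault(code, block) != block or b2c.setdefault(block, code) != code:
                    ok = False
                    break
            if ok:
                flag.append(ch)
                yield from extend(pos + 1, c2b, b2c, flag)
                flag.pop()  # backtrack and try the next sbox

    yield from extend(0, {}, {}, [])

toy_sboxes = [[0, 1, 2, 3], [4, 0, 5, 6], [7, 8, 9, 10]]
secret = [2, 1, 0]
# stand-in for AES on the toy codes: code -> code + 100
toy_cip = [[code + 100 for code in toy_sboxes[ch]] for ch in secret]
print(list(solve_substitution(toy_cip, toy_sboxes)))  # [[2, 1, 0]]
```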