Black Hat MEA 2022 Qual - Crypto

Black Hat MEA CTF is a CTF competition organized by Black Hat Middle East and Africa in collaboration with the Saudi Federation for Cybersecurity, Programming & Drones (SAFCSP). More than 400 teams participated in this CTF. I played with the AmpunBangJago team, and we finished #18 when the competition ended. Here is a write-up of all the crypto challenges in this CTF.

Ursa Minor (150 pts, 51 solves)

Challenge

#!/usr/local/bin/python
#
# Polymero
#

# Imports
from Crypto.Util.number import isPrime, getPrime, inverse
import hashlib, time, os

# The flag is read from the FLAG environment variable; if it is unset,
# os.environ.get returns None and .encode() raises AttributeError.
FLAG = os.environ.get('FLAG').encode()


class URSA:
    """RSA variant used by the challenge server.

    Public values live in self.public ('n', 'e'), secret ones in
    self.private ('p', 'q', 'f' = (p-1)(q-1), 'd'). update_key() changes
    only e and d, so n stays the same for the object's lifetime.
    """
    # Upgraded RSA (faster and with cheap key cycling)
    def __init__(self, pbit, lbit):
        """Build the key from prime_gen(pbit, lbit) with e = 0x10001."""
        p, q = self.prime_gen(pbit, lbit)
        self.public = {'n': p * q, 'e': 0x10001}
        self.private = {'p': p, 'q': q, 'f': (p - 1)*(q - 1), 'd': inverse(self.public['e'], (p - 1)*(q - 1))}
        
    def prime_gen(self, pbit, lbit):
        """Return two primes (P, Q), each equal to
        2 * (product of pbit // lbit random lbit-bit primes) + 1.

        P - 1 and Q - 1 therefore have only small prime factors (they are
        smooth), which is the precondition for Pollard's p - 1 factoring
        method.
        """
        # Smooth primes are FAST primes ~ !
        while True:
            qlst = [getPrime(lbit) for _ in range(pbit // lbit)]
            # NOTE(review): this retries whenever len(qlst) - len(set(qlst)) <= 1,
            # so only lists with at least two surplus (repeated) entries are
            # accepted -- the opposite of a typical uniqueness check.
            if len(qlst) - len(set(qlst)) <= 1:
                continue
            q = 1
            for ql in qlst:
                q *= ql
            Q = 2 * q + 1
            if isPrime(Q):
                break
        # Same construction (and same duplicate check) for the second prime.
        while True:
            plst = [getPrime(lbit) for _ in range(pbit // lbit)]
            if len(plst) - len(set(plst)) <= 1:
                continue
            p = 1
            for pl in plst:
                p *= pl
            P = 2 * p + 1
            if isPrime(P):
                break 
        return P, Q
    
    def update_key(self):
        """Re-randomise the key pair without touching n.

        d becomes (d XOR SHA-512(str(d) + str(time.time()))) mod phi, and e
        is recomputed as d^-1 mod phi.
        """
        # NOTE(review): nothing checks gcd(d, phi) == 1; what happens when the
        # new d is not invertible depends on Crypto.Util.number.inverse.
        # Prime generation is expensive, so we'll just update d and e instead ^w^
        self.private['d'] ^= int.from_bytes(hashlib.sha512((str(self.private['d']) + str(time.time())).encode()).digest(), 'big')
        self.private['d'] %= self.private['f']
        self.public['e'] = inverse(self.private['d'], self.private['f'])
        
    def encrypt(self, m_int):
        """Return [pow(m, e, n), pow(m // n, e, n), ...] until m reaches 0.

        A message smaller than n yields a single ciphertext; 0 yields [].
        NOTE(review): for a negative m_int, `m_int //= n` settles at -1 and
        the loop never terminates.
        """
        c_lst = []
        while m_int:
            c_lst += [pow(m_int, self.public['e'], self.public['n'])]
            m_int //= self.public['n']
        return c_lst
    
    def decrypt(self, c_int):
        """Return [pow(c, d, n), pow(c // n, d, n), ...] until c reaches 0.

        Uses the *current* d. NOTE(review): like encrypt(), a negative c_int
        never terminates.
        """
        m_lst = []
        while c_int:
            m_lst += [pow(c_int, self.private['d'], self.public['n'])]
            c_int //= self.public['n']
        return m_lst


# Challenge setup
print("""|
|  ~ Welcome to URSA decryption services
|    Press enter to start key generation...""")

input("|")

print("""|
|    Please hold on while we generate your primes...
|\n|""")
    
# Primes built from pbit // lbit = 21 twelve-bit factors (see URSA.prime_gen).
oracle = URSA(256, 12)
print("|  ~ You are connected to an URSA-256-12 service, public key ::")
# Only the SHA-256 of n is revealed, never n itself.
print("|    id = {}".format(hashlib.sha256(str(oracle.public['n']).encode()).hexdigest()))
print("|    e  = {}".format(oracle.public['e']))

# The flag is encrypted once, under the initial e = 0x10001.
print("|\n|  ~ Here is a free flag sample, enjoy ::")
for i in oracle.encrypt(int.from_bytes(FLAG, 'big')):
    print("|    {}".format(i))


# {} is filled with the number of requests left before the forced update.
MENU = """|
|  ~ Menu (key updated after {} requests)::
|    [E]ncrypt
|    [D]ecrypt
|    [U]pdate key
|    [Q]uit
|"""

# Server loop
# CYCLE counts requests since the last key update. It starts at 0, so the
# first iteration forces choice 'u': the key is re-randomised right after
# the flag sample. After an update CYCLE is reset to 0 and then incremented
# to 1, so the user gets 3 requests (CYCLE 1, 2, 3) before the next forced
# update at CYCLE == 4.
CYCLE = 0
while True:
    
    try:

        if CYCLE % 4:
            print(MENU.format(4 - CYCLE))
            choice = input("|  > ")

        else:
            choice = 'u'
        
        if choice.lower() == 'e':
            # Encrypts an arbitrary integer under the current e.
            msg = int(input("|\n|  > (int) "))

            print("|\n|  ~ Encryption ::")
            for i in oracle.encrypt(msg):
                print("|    {}".format(i))

        elif choice.lower() == 'd':
            # Decrypts with the *current* d, which no longer matches the
            # e = 0x10001 used for the flag sample.
            cip = int(input("|\n|  > (int) "))

            print("|\n|  ~ Decryption ::")
            for i in oracle.decrypt(cip):
                print("|    {}".format(i))
            
        elif choice.lower() == 'u':
            # Only e and d change; n (and therefore the printed id) stays the same.
            oracle.update_key()
            print("|\n|  ~ Key updated succesfully ::")
            print("|    id = {}".format(hashlib.sha256(str(oracle.public['n']).encode()).hexdigest()))
            print("|    e  = {}".format(oracle.public['e']))

            CYCLE = 0
            
        elif choice.lower() == 'q':
            print("|\n|  ~ Closing services...\n|")
            break
            
        else:
            print("|\n|  ~ ERROR - Unknown command")

        # Reached only when no exception was raised, so rejected input does
        # not count towards the key-update cycle.
        CYCLE += 1
        
    except KeyboardInterrupt:
        print("\n|  ~ Closing services...\n|")
        break
        
    # Any other error (e.g. int() on non-numeric input) is reported and the
    # loop continues.
    except:
        print("|\n|  ~ Please do NOT abuse our services.\n|")

TL;DR

  • The modulus $ n $ is hidden, but the SHA-256 hash of $ n $ is given
  • The flag is encrypted with the initial public exponent $ e = 65537 $
  • The primes $ p $ and $ q $ are smooth ($ p - 1 $ and $ q - 1 $ are 2 times a product of 12-bit primes), so $ n $ can be factored with factordb (or Pollard's $ p - 1 $ method)
  • The key pair $ (e, d) $ is updated right after the flag is encrypted, and then automatically after every 3 requests; the modulus $ n $ never changes

Solution

We can use the encryption option of the service to recover the modulus $ n $, in a similar way to NSUCRYPTO 2020: Hidden RSA. Let $ c_{2} = \mathrm{enc}(2) = 2^{e} \bmod n $, $ c_{4} = \mathrm{enc}(4) = 4^{e} \bmod n $ and $ c_{8} = \mathrm{enc}(8) = 8^{e} \bmod n $, all requested within one key cycle so that they share the same $ e $. Because $ 4^{e} = 2^{e} \cdot 2^{e} $ and $ 8^{e} = 2^{e} \cdot 4^{e} $, we get:

  • $ c_{2}^{2} - c_{4} \equiv 2^{2e} - 4^{e} \equiv 0 \pmod n $, so $ c_{2}^{2} - c_{4} = k_{1} n $

and

  • $ c_{2} c_{4} - c_{8} \equiv 2^{e} 4^{e} - 8^{e} \equiv 0 \pmod n $, so $ c_{2} c_{4} - c_{8} = k_{2} n $

for some integers $ k_{1}, k_{2} $. We can recover $ n $ by calculating $ \mathrm{GCD}(k_{1} n, k_{2} n) = \mathrm{GCD}(k_{1}, k_{2}) \cdot n $ and stripping the (usually small) factor $ \mathrm{GCD}(k_{1}, k_{2}) $ by trial division. The SHA-256 hash printed by the server confirms the result. After $ n $ is recovered, factorize it with factordb; the rest is standard RSA decryption of the flag with $ e = 65537 $.

Implementation

from deom import *
import hashlib

# Challenge instance; the connection uses TLS on port 443 with SNI set to HOST.
# `remote`/`process` come from the author's `deom` helper module
# (presumably pwntools-like -- not visible here).
HOST = 'blackhat2-5363975a5ad0a94a2e81ff6ea55e972a-0.chals.bh.ctf.sa'

# io = process('python3 ursaminor.py'.split())
io = remote(HOST, 443, ssl=True, sni=HOST)

def get_initial_values():
    """Start key generation and read the values printed before the menu.

    Returns [H_n, e, c]: the SHA-256 hex digest of n (str), the initial
    public exponent (int) and the first flag-ciphertext line (int). Only
    one ciphertext line is read, which assumes the flag is smaller than n.
    """
    io.sendlineafter(b'generation...\n|', b'')

    def read_after(marker):
        # Skip everything up to `marker` and return the rest of that line.
        io.recvuntil(marker)
        return io.recvline(0).decode()

    hash_of_n = read_after(b'id = ')
    exponent_text = read_after(b'e  = ')
    flag_ct_text = read_after(b'enjoy ::\n|    ')
    return [hash_of_n, int(exponent_text), int(flag_ct_text)]

def send_encrypt(m):
    """Ask the oracle to encrypt integer m; return the first ciphertext line as int."""
    io.sendlineafter(b'|  > ', b'e')
    io.sendlineafter(b'|  > (int) ', str(m).encode())
    io.recvuntil(b'Encryption ::\n|    ')
    return int(io.recvline(0).decode())

H_n, e, c = get_initial_values()

# Three ciphertexts under one key: c2 = 2^e, c4 = 4^e, c8 = 8^e (mod n).
c2, c4, c8 = (send_encrypt(m) for m in (2, 4, 8))

# c2^2 - c4 and c2*c4 - c8 are both integer multiples of n.
k1n = c2 * c2 - c4
k2n = c2 * c4 - c8
n = gcd(k1n, k2n)

# Remove small cofactors left over from gcd(k1, k2).
for small in range(2, 1002):
    while not n % small:
        n //= small

print(f'{n = }')
print(hashlib.sha256(str(n).encode()).hexdigest() == H_n)
print(f'{c = }')

io.close()
n = 1901677835465328762525045529900644508369351221978409915488480263646618951530054167614329569744978490699403857261498216571695214295951453028840101801
True
c = 797581560322388898698318083907163541431496361921829824150624839779318880101350859625599903033391266502361090089553828809117706944973035239443998532
from deom import *

# Recovered modulus and its factors (obtained from factordb).
n = 1901677835465328762525045529900644508369351221978409915488480263646618951530054167614329569744978490699403857261498216571695214295951453028840101801
p = 37412853947783363258275868432695264738275815525630275165612551890384492447
q = 50829531425736083666184827746380554905641232404537597577785718063669472183
if p * q != n:
    raise ValueError('p * q does not match the recovered modulus n')
# The flag sample was encrypted before the first key update, i.e. with e = 65537.
e = 65537
# Modular inverse via the stdlib (Python 3.8+), no helper module needed.
d = pow(e, -1, (p - 1) * (q - 1))
c = 797581560322388898698318083907163541431496361921829824150624839779318880101350859625599903033391266502361090089553828809117706944973035239443998532
m = pow(c, d, n)
# Big-endian integer -> bytes, the inverse of the server's int.from_bytes(FLAG, 'big').
flag = m.to_bytes((m.bit_length() + 7) // 8, 'big')
print(flag)

Flag

BlackHatMEA{471:19:354a4d73bee389d4128a1f96fc5514208fcdb637}

Nothing Up My Sbox (250 pts, 16 solves)

Challenge

#!/usr/local/bin/python
#
# Polymero
#

# Imports
import os, time
from secrets import randbelow
from hashlib import sha256

# The flag is read from the FLAG environment variable; if it is unset,
# os.environ.get returns None and .encode() raises AttributeError.
FLAG = os.environ.get('FLAG').encode()


class NUMSBOX:
    """Toy block cipher on 16-nibble blocks used by the challenge server.

    The s-box is derived from `seed` (the player name in this challenge),
    the p-box from the current time, and `key` is a list of 16 nibbles.
    """
    def __init__(self, seed, key):
        self.sbox = self.gen_box('SBOX :: ' + seed)
        # The p-box seed is the wall-clock time, so it differs per connection.
        self.pbox = self.gen_box(str(time.time()))
        self.key = key

    def gen_box(self, seed):
        """Return a permutation of 0..15 derived from `seed`.

        Hashes seed || i (i is a 2-byte big-endian counter starting at 1)
        with SHA-256 and takes the hex digits in order, skipping values that
        were already taken, until all 16 values have appeared.
        """
        box = []
        i = 0
        while len(box) < 16:
            i += 1
            h = sha256(seed.encode() + i.to_bytes(2, 'big')).hexdigest()
            for j in h:
                b = int(j, 16)
                if b not in box:
                    box += [b]
        return box
    
    def subs(self, x):
        """Substitute every nibble through the s-box."""
        return [self.sbox[i] for i in x]
    
    def perm(self, x):
        """Reorder positions: output[j] = x[pbox[j]]."""
        return [x[i] for i in self.pbox]
    
    def kxor(self, x, k):
        """XOR x with k position by position."""
        return [i ^ j for i,j in zip(x, k)]
    
    def encrypt(self, msg):
        """Encrypt a list of nibbles and return one hex character per nibble.

        The message is padded with `pad` copies of the value `pad` up to a
        multiple of 16. Each block gets an initial key XOR followed by 4
        rounds of subs -> perm -> key XOR. Every step acts on single nibbles
        or only moves them, so output nibble j depends only on the nibbles
        at positions in j's p-box cycle.
        NOTE(review): `msg +=` extends the caller's list in place when
        padding is needed; the server passes a freshly built list.
        """
        if len(msg) % 16:
            msg += (16 - (len(msg) % 16)) * [16 - (len(msg) % 16)]
        blocks = [msg[i:i+16] for i in range(0, len(msg), 16)]
        cip = []
        for b in blocks:
            x = self.kxor(b, self.key)
            for _ in range(4):
                x = self.subs(x)
                x = self.perm(x)
                x = self.kxor(x, self.key)
            cip += x
        return ''.join([hex(i)[2:] for i in cip])
    
    
# Secret key: 16 random nibbles. It keys both the OTP below and the oracle.
KEY = [randbelow(16) for _ in range(16)]

# OTP = concatenated SHA-256(b"OTP :: " + str(KEY) + b" :: " + offset) digests,
# where offset is the current OTP length as 2 bytes. Recovering KEY is
# therefore enough to rebuild the OTP and decrypt the flag.
OTP = b""
while len(OTP) < len(FLAG):
    OTP += sha256(b" :: ".join([b"OTP", str(KEY).encode(), len(OTP).to_bytes(2, 'big')])).digest()
    
encflag = bytes([i ^ j for i,j in zip(FLAG, OTP)]).hex()

print("|\n|  ~ In order to prove that I have nothing up my sleeve, I let you decide on the sbox!")
print("|    I am so confident, I will even stake my flag on it ::")
print("|    flag = {}".format(encflag))

print("|\n|  ~ Now, player, what should I call you?")
# The player name becomes the s-box seed.
seed = input("|\n|  > ")

oracle = NUMSBOX(seed, KEY)

print("|\n|  ~ Well {}, here are your s- and p-box ::".format(seed))
print("|    s-box = {}".format(oracle.sbox))
print("|    p-box = {}".format(oracle.pbox))


MENU = """|
|  ~ Menu ::
|    [E]ncrypt
|    [Q]uit
|"""

while True:

    try:

        print(MENU)
        choice = input("|  > ")

        if choice.lower() == 'e':
            # Each input character is one nibble; non-hex characters raise
            # ValueError, which the bare except below reports.
            msg = [int(i, 16) for i in input("|\n|  > (hex) ")]
            print("|\n|  ~ {}".format(oracle.encrypt(msg)))

        elif choice.lower() == 'q':
            print("|\n|  ~ Sweeping the boxes back up my sleeve...\n|")
            break

        else:
            print("|\n|  ~ Sorry I do not know what you mean...")

    except KeyboardInterrupt:
        print("\n|  ~ Sweeping the boxes back up my sleeve...\n|")
        break

    # Any other error is reported and the loop continues.
    except:
        print("|\n|  ~ Hey, be nice to my code, okay?")

TL;DR

  • The flag is encrypted by XOR with an OTP that is derived from KEY
  • KEY is a list of 16 random numbers in the range 0-15
  • KEY is also the key of the oracle object created from the NUMSBOX class
  • We can encrypt any message with the oracle
  • The oracle splits the plaintext into blocks of 16 nibbles (hex digits), XORs each block with KEY, and then applies 4 rounds of sbox substitution, pbox permutation, and KEY xor
  • The sbox and pbox are permutations of 0-15 derived from the SHA-256 of a seed
  • The sbox seed is controlled by the user input (the player name), and the pbox seed is the current time

Solution

The pbox permutation shuffles the nibble positions, so I made a mapping of it to make it easier to work with. I also made a replica of the oracle object using the same sbox and pbox but with a custom KEY, to find which key positions affect each ciphertext position. The substitution and the key xor work on single nibbles, and the pbox only moves nibbles around. As a result, the positions split into independent groups (the cycles of the pbox). In most cases, one key nibble affects 4-5 ciphertext nibbles. Since every nibble is in the range 0-15, brute-forcing a group of 4-5 nibbles is feasible: there are at most 16**5 = 2^20 possibilities.

The pbox seed is the current time, so we reconnect to the service until we get a “good” pbox. Here, “good” means that the set_box (the positions of the groups that are too large to brute-force immediately, which are left for the final step) is small and there are only a few candidate keys per group (see the implementation below).

Implementation

from deom import *
import hashlib
import itertools
import time

# Challenge instance; the connection uses TLS on port 443 with SNI set to HOST.
# `remote`/`process` come from the author's `deom` helper module
# (presumably pwntools-like -- not visible here).
HOST = 'blackhat4-1f456da30f47d507a718e01fbc35a3bc-0.chals.bh.ctf.sa'

# io = process('python3 nothingupmysbox.py'.split(), level='warn')
io = remote(HOST, 443, ssl=True, sni=HOST, level='warn')

class NUMSBOX:
    """Local replica of the challenge cipher with known boxes.

    Works on lists of nibbles (ints 0-15). `sbox` and `pbox` are the
    permutations printed by the server; `key` is a 16-nibble candidate key.
    """

    def __init__(self, sbox, pbox, key):
        self.sbox = sbox
        self.pbox = pbox
        self.key = key

    def subs(self, x):
        """Substitute every nibble through the s-box."""
        return [self.sbox[i] for i in x]

    def perm(self, x):
        """Reorder positions: output[j] = x[pbox[j]]."""
        return [x[i] for i in self.pbox]

    def inv_subs(self, x):
        """Undo subs(); requires sbox to be a permutation of 0..15."""
        return [self.sbox.index(i) for i in x]

    def inv_perm(self, x):
        """Undo perm(): input[i] = output[j] where pbox[j] == i."""
        # Must use this instance's p-box, not a module-level `pbox`.
        return [x[self.pbox.index(i)] for i in range(len(self.pbox))]

    def kxor(self, x, k):
        """XOR x with k position by position."""
        return [i ^ j for i, j in zip(x, k)]

    def encrypt(self, msg, round=4):
        """Encrypt a list of nibbles; return one hex character per nibble.

        msg is padded with `pad` copies of the value `pad` up to a multiple
        of 16, without modifying the caller's list. Each block gets an
        initial key XOR and then `round` rounds of subs -> perm -> key XOR.
        """
        pad = -len(msg) % 16
        msg = list(msg) + [pad] * pad
        cip = []
        for start in range(0, len(msg), 16):
            x = self.kxor(msg[start:start + 16], self.key)
            for _ in range(round):
                x = self.kxor(self.perm(self.subs(x)), self.key)
            cip += x
        return ''.join(hex(i)[2:] for i in cip)

    def decrypt(self, msg, round=4):
        """Invert encrypt(); return the recovered nibbles as a hex string.

        msg may be a list of nibbles or the hex string returned by
        encrypt(). Padding is not removed.
        """
        if isinstance(msg, str):
            msg = [int(ch, 16) for ch in msg]
        out = []
        for start in range(0, len(msg), 16):
            # Peel off the last key XOR, then undo each round in reverse.
            x = self.kxor(msg[start:start + 16], self.key)
            for _ in range(round):
                x = self.kxor(self.inv_subs(self.inv_perm(x)), self.key)
            out += x
        return ''.join(hex(i)[2:] for i in out)

def arr2str(arr):
    """Render a list of nibbles as a hex string, one character per value."""
    pieces = (hex(value)[2:] for value in arr)
    return ''.join(pieces)

def str2arr(s):
    """Parse a hex string into a list of nibbles, one int per character."""
    return list(map(lambda ch: int(ch, 16), s))

import ast

# Encrypted flag (hex), printed before the name prompt.
io.recvuntil(b'flag = ')
enc_flag = io.recvline(0).decode()
print('enc_flag', enc_flag)

# A fixed name keeps the s-box seed ('SBOX :: deom') and hence the s-box
# constant; only the time-seeded p-box changes between connections.
io.sendlineafter(b'> ', b'deom')

# The boxes arrive as Python list literals. ast.literal_eval parses them
# without executing anything, unlike eval() on data from the network.
io.recvuntil(b's-box = ')
sbox = ast.literal_eval(io.recvline(0).decode())

io.recvuntil(b'p-box = ')
pbox = ast.literal_eval(io.recvline(0).decode())

print(sbox)
print(pbox)

def send_encrypt(pt):
    """Submit hex plaintext `pt` to the [E]ncrypt option; return the
    ciphertext hex string the server prints."""
    io.sendlineafter(b'> ', b'e')
    io.sendlineafter(b'> (hex) ', str(pt).encode())
    io.recvuntil(b'~ ')
    return io.recvline(0).decode()

################################################################

# Debug log of local probes: [probe_key, ciphertext_hex, key_position].
super_box = []

def find_origin_index(idx):
    """Find which key positions influence ciphertext nibble `idx`.

    Encrypts the all-zero block locally with keys that are zero except at
    position i, and records i when output nibble `idx` differs from the
    all-zero-key baseline. Every nonzero value is tried, so a position is
    not missed when one particular value happens to leave `idx` unchanged.

    Also asks the server to encrypt the all-zero block. Returns
    [ori_idx, target], where target holds the server's ciphertext nibbles
    at the positions listed in ori_idx. Using those positions is valid
    because every cipher step is nibble-wise or only moves nibbles, so the
    influencing positions form a p-box cycle whose outputs depend only on
    the same key nibbles.
    """
    ori_idx = []
    x0 = str2arr('0000000000000000')
    # Computed from the actual boxes instead of hard-coding 'b' (which only
    # holds for the s-box generated from the name 'deom').
    baseline = NUMSBOX(sbox, pbox, [0] * 16).encrypt(x0, round=4)[idx]

    for i in range(16):
        probes = []
        for value in range(1, 16):
            tq = [0] * 16
            tq[i] = value
            probes.append((tq, NUMSBOX(sbox, pbox, tq).encrypt(x0, round=4)))
        # Keep the same debug record as before: the value-1 probe.
        super_box.append([probes[0][0], probes[0][1], i])
        if any(y[idx] != baseline for _, y in probes):
            ori_idx.append(i)

    # NOTE(review): every call re-sends the same all-zero query; the answer
    # does not depend on idx, so it could be fetched once.
    tmp = send_encrypt(arr2str(x0))
    target = [tmp[i] for i in ori_idx]
    return [ori_idx, target]

ori_idx_box = []

# Several ciphertext positions map to the same (positions, target) group;
# keep each distinct group once, in discovery order.
for i in range(16):
    res = find_origin_index(i)
    if res in ori_idx_box:
        continue
    ori_idx_box.append(res)

for ori_idx, target in ori_idx_box:
    print('>>>', ori_idx, target)

################################################################

def bf_small_key(ori_idx, target):
    """Brute-force the key nibbles at positions `ori_idx`.

    Tries every assignment of 0-15 to those positions (all other key
    nibbles are 0), encrypts the all-zero block locally, and keeps the keys
    whose ciphertext nibbles at `ori_idx` equal `target`. Each match is
    printed, as before, and the list of matches is returned.
    """
    x0 = str2arr('0000000000000000')
    matches = []
    # Iterate lazily instead of materialising all 16**len(ori_idx) tuples.
    for pos in itertools.product(range(16), repeat=len(ori_idx)):
        tq = [0] * 16
        for i, value in zip(ori_idx, pos):
            tq[i] = value
        y = NUMSBOX(sbox, pbox, tq).encrypt(x0, round=4)
        if all(y[i] == t for i, t in zip(ori_idx, target)):
            print('tq', tq)
            matches.append(tq)
    return matches

set_box = []

# Groups of at most 4 positions are brute-forced now (<= 16**4 keys);
# larger groups are merged into set_box for the final OTP search.
for box in ori_idx_box:
    positions, target = box
    if len(positions) > 4:
        set_box += positions
        continue
    print('box', box)
    bf_small_key(positions, target)

set_box = set(set_box)
print('set_box', len(set_box), '->', list(set_box))
print()

################################################################

# Disconnect and wait a second before exiting. The script is meant to be
# re-run in the shell loop below until a connection yields a small set_box;
# each connection gets a new p-box because the server seeds it with time.time().
io.close()
time.sleep(1)

# while true; do python3 solve.py; done
...

enc_flag bcbf00f75a100ace7d1282fe857693ed1fa7c4e4957e22790b916784d791d0ec1a62f629dfa380c31d3e4cf4f846f660ad7da90607de7838fb4b1273
[7, 11, 2, 15, 9, 14, 8, 12, 5, 4, 13, 3, 1, 10, 0, 6]
[1, 3, 13, 9, 7, 2, 6, 11, 10, 15, 14, 12, 4, 5, 8, 0]
>>> [0, 1, 3, 9, 15] ['c', '0', '0', 'e', 'f']
>>> [2, 5, 13] ['5', 'f', '6']
>>> [4, 7, 11, 12] ['e', '9', '3', '2']
>>> [6] ['6']
>>> [8, 10, 14] ['a', 'f', 'e']
box [[2, 5, 13], ['5', 'f', '6']]
tq [0, 0, 10, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
tq [0, 0, 14, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
box [[4, 7, 11, 12], ['e', '9', '3', '2']]
tq [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 5, 10, 0, 0, 0]
tq [0, 0, 0, 0, 10, 0, 0, 14, 0, 0, 0, 5, 5, 0, 0, 0]
box [[6], ['6']]
tq [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
tq [0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0]
box [[8, 10, 14], ['a', 'f', 'e']]
tq [0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 10, 0, 0, 0, 8, 0]
tq [0, 0, 0, 0, 0, 0, 0, 0, 11, 0, 10, 0, 0, 0, 7, 0]
tq [0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 3, 0, 0, 0, 6, 0]
set_box 5 -> [0, 1, 3, 9, 15]

...

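Finally, we got a small set_box (5 positions). Copy the candidate keys from the output above into the code below. It combines them and brute-forces the set_box positions directly against the OTP.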

from functools import reduce
from hashlib import sha256
from tqdm import tqdm
import itertools

def vec_add(x, y):
    """Return a new list whose k-th entry is x[k] + y[k], for every index of x."""
    return [x[k] + y[k] for k in range(len(x))]

# Positions of the large group found in the previous step; they are
# brute-forced directly against the OTP (16**5 candidates per partial key).
set_box = [0, 1, 3, 9, 15]
# Every assignment of 0-15 to the set_box positions, materialised once so it
# can be re-iterated for each partial key below.
poss = list(itertools.product(*[list(range(16)) for _ in range(len(set_box))]))
# Encrypted flag printed by the server (flag XOR OTP derived from KEY).
enc = bytes.fromhex('bcbf00f75a100ace7d1282fe857693ed1fa7c4e4957e22790b916784d791d0ec1a62f629dfa380c31d3e4cf4f846f660ad7da90607de7838fb4b1273')

# Candidate keys found by bf_small_key, one inner list per small group
# (copied from the output above); each is nonzero only on its own group.
box = [
    [
        [0, 0, 10, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 14, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ],
    [
        [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 5, 10, 0, 0, 0],
        [0, 0, 0, 0, 10, 0, 0, 14, 0, 0, 0, 5, 5, 0, 0, 0],
    ],
    [
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ],
    [
        [0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 10, 0, 0, 0, 8, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 11, 0, 10, 0, 0, 0, 7, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 3, 0, 0, 0, 6, 0],
    ],
]

# Pick one candidate from every small group and merge them into a partial
# key; the groups are nonzero on disjoint positions, so element-wise
# addition simply combines them.
xs = list(itertools.product(*box))
arrs = [reduce(vec_add, list(x)) for x in xs]

for arr in arrs:
    print('arr', arr)
    for pos in tqdm(poss):
        # Work on a copy so the partial key in `arrs` is never modified
        # (the original `tmp = arr` aliased and mutated it).
        KEY = list(arr)
        for i, j in zip(set_box, pos):
            KEY[i] = j
        # Rebuild the server's OTP; str(KEY) must match the server's
        # formatting, so KEY stays a plain list of ints.
        OTP = b""
        while len(OTP) < len(enc):
            OTP += sha256(b" :: ".join([b"OTP", str(KEY).encode(), len(OTP).to_bytes(2, 'big')])).digest()

        flag = bytes([i ^ j for i, j in zip(enc, OTP)])
        if b'BlackHatMEA' in flag:
            print(flag.decode())
            # Same effect as exit(), without relying on the site module.
            raise SystemExit

Flag

BlackHatMEA{471:20:fcef0de6edab173edf977a452a6eef14c98e2253}