UMassCTF 2022 - Crypto

UMassCTF is a 48-hour Jeopardy-style Capture The Flag event hosted by the University of Massachusetts Amherst Cybersecurity Club. I played with the AmpunBangJago team, and we were the only team to solve hatmash, one of the crypto challenges.

hatmash (500 points, 1 solve)

Challenge

What do you mean “We think you spend too much time with matrices.”? It’s just a hash function, jeez…

nc 34.139.216.197 10001

#!/usr/bin/env python3
import os

# Load the flag once at startup. The with-statement closes the file when the
# block ends, so no explicit close() call is needed.
with open('flag.txt','rb') as f:
    FLAG = f.read()

def bytes_to_mat(x):
    """Unpack exactly 32 bytes into a 16x16 matrix of 0/1 ints.

    Bits are taken most-significant first and laid out row by row, so the
    top bit of x[0] becomes entry [0][0] and the low bit of x[31] becomes
    entry [15][15].
    """
    assert len(x) == 32
    value = int.from_bytes(x, 'big')
    return [
        [(value >> (255 - (16 * row + col))) & 1 for col in range(16)]
        for row in range(16)
    ]

def mat_to_bytes(x):
    """Pack a 0/1 matrix (row-major, first entry = most significant bit) into bytes.

    The output length is ceil(rows * cols / 8), where cols is taken from the
    first row.
    """
    n_rows, n_cols = len(x), len(x[0])
    bitstring = ''.join(str(entry) for row in x for entry in row)
    n_bytes = (n_rows * n_cols + 7) // 8
    return int(bitstring, 2).to_bytes(n_bytes, 'big')

def mod_mult(a,b,m):
    """Return the matrix product a * b with every entry reduced modulo m.

    a is an r x n and b an n x c list-of-lists matrix; the result is r x c.
    The column count comes from b (len(b[0])); the original code used
    len(a), which only coincides for square operands.
    """
    assert len(a[0]) == len(b)
    inner = range(len(b))
    cols = range(len(b[0]))
    return [[sum(a[k][i] * b[i][j] for i in inner) % m for j in cols] for k in range(len(a))]

def mod_add(a,b,m):
    """Return the entry-wise sum of equally shaped matrices a and b, modulo m."""
    assert len(a[0]) == len(b[0]) and len(a) == len(b)
    width = range(len(a[0]))
    return [[(row_a[j] + row_b[j]) % m for j in width] for row_a, row_b in zip(a, b)]

KEY = os.urandom(3 * 32)
print('KEY:', KEY.hex())

# Split the 96 key bytes into three interleaved 32-byte streams
# (offsets 0, 1, 2 with stride 3) and unpack each one into a matrix.
A, B, C = (bytes_to_mat(KEY[offset::3]) for offset in range(3))

def mash(x):
    """Hash the byte string x with the module-level matrices A, B and C.

    Each bit of x, most significant first, selects A (bit 0) or B (bit 1).
    The selected matrices are multiplied together mod 2. The length term
    C^(len(x)+1) is then added mod 2, and the sum is packed with mat_to_bytes.
    """
    bitstring = '{:0{n}b}'.format(int.from_bytes(x, 'big'), n=8 * len(x))
    ret = A if bitstring[0] == '0' else B
    for bit in bitstring[1:]:
        ret = mod_mult(ret, A if bit == '0' else B, 2)
    # Start from C and multiply len(x) more times: C^(len(x)+1).
    length_term = C
    for _ in range(len(x)):
        length_term = mod_mult(length_term, C, 2)
    return mat_to_bytes(mod_add(ret, length_term, 2))

target_hash = mash(b"gib m3 flag plox?").hex()
print('TARGET:', target_hash)

# Allowed input bytes: 0x20 (' ') up to 0x7d ('}'); range() excludes '~' itself.
ALP = range(ord(' '), ord('~'))

try:
    user_msg = input().encode()
    # Explicit check instead of `assert`: asserts are removed under
    # `python -O`, which would silently disable this filter.
    if not all(i in ALP for i in user_msg):
        raise ValueError('message contains characters outside ALP')

    if b"gib m3 flag plox?" in user_msg:
        print('Uuh yeah nice try...')

    elif mash(user_msg).hex() == target_hash:
        print('Wow, well I suppose you deserve it {}'.format(FLAG.decode()))

    else:
        print('Not quite, try again...')

# Top-level boundary: any error from bad input or hashing gets the same reply,
# but KeyboardInterrupt/SystemExit are no longer swallowed.
except Exception:
    print('Try to be serious okay...')

Solution

We need to find an input that collides with the string gib m3 flag plox? under this matrix-based hash function. The matrices $ A, B, C $ are generated from a fresh random key on every connection, but the server prints that key, so we can rebuild them locally. A collision is possible when all of $ A, B, C $ are invertible. In that case $ A^{\mathrm{ord}(A)+1} = A $, where $ \mathrm{ord}(A) $ is the multiplicative order of $ A $ over $ GF(2) $, i.e. the smallest positive $ n $ with $ A^n = I $.

sage: A = random_matrix(GF(2), 10, 10); A
[1 0 1 0 1 0 1 1 0 0]
[0 1 0 0 0 0 0 1 0 0]
[0 0 1 0 0 1 1 0 0 1]
[1 1 0 1 0 1 0 1 1 1]
[0 0 1 0 0 1 0 1 1 1]
[1 1 1 0 0 0 0 1 1 1]
[1 0 0 1 0 0 0 0 0 1]
[0 0 0 0 1 1 0 0 0 1]
[1 1 0 1 1 0 1 1 0 0]
[1 1 1 0 0 1 1 1 0 0]
sage: A.is_invertible()
True
sage: ord_A = A.multiplicative_order(); ord_A
372
sage: A^(ord_A+1) == A
True
sage: A*(ord_A+1) == A
True

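Let $ \mathrm{ord}(x) $ be the multiplicative order of the matrix formed by character x. Apart from the length term, the hash of a message is the product of the matrices of its characters. Replacing a character x with $ \mathrm{ord}(x)+1 $ copies of itself therefore does not change that product. We build a payload string gib m3 flag ploxxx...xxx? where character x repeats $ \mathrm{ord}(x)+1 $ times. In practice we expand the character of the target message with the smallest multiplicative order, to keep the payload short.

However, the term $ C^{\mathrm{len}+1} $ makes the hash depend on the input length, so the final payload length must satisfy $ \mathrm{len(payload)} \equiv 17 \pmod{\mathrm{ord}(C)} $. To fix the length, we append a padding string z (a single character or a pair of characters) repeated $ m $ times. This leaves the product unchanged as long as $ \mathrm{ord}(z) \mid m $. Let $ \Delta l = (17 - \mathrm{len(payload)}) \bmod \mathrm{ord}(C) $. We choose the smallest $ k $ (and a matching z) such that $ \Delta l + k \times \mathrm{ord}(C) \equiv 0 \pmod{|z| \cdot \mathrm{ord}(z)} $, and append z exactly $ (\Delta l + k \times \mathrm{ord}(C)) / |z| $ times.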

Implementation

#!/usr/bin/env python3
from pwn import *
from sage.all import *
from libnum import s2b
import sys

def bytes_to_mat(x):
    """Unpack exactly 32 bytes into a 16x16 matrix of 0/1 ints.

    Bits are taken most-significant first and laid out row by row, so the
    top bit of x[0] becomes entry [0][0] and the low bit of x[31] becomes
    entry [15][15].
    """
    assert len(x) == 32
    value = int.from_bytes(x, 'big')
    return [
        [(value >> (255 - (16 * row + col))) & 1 for col in range(16)]
        for row in range(16)
    ]

def mat_to_bytes(x):
    """Pack a 0/1 matrix (row-major, first entry = most significant bit) into bytes.

    The output length is ceil(rows * cols / 8), where cols is taken from the
    first row.
    """
    n_rows, n_cols = len(x), len(x[0])
    bitstring = ''.join(str(entry) for row in x for entry in row)
    n_bytes = (n_rows * n_cols + 7) // 8
    return int(bitstring, 2).to_bytes(n_bytes, 'big')

def mod_mult(a,b,m):
    """Return the matrix product a * b with every entry reduced modulo m.

    a is an r x n and b an n x c list-of-lists matrix; the result is r x c.
    The column count comes from b (len(b[0])); the original code used
    len(a), which only coincides for square operands.
    """
    assert len(a[0]) == len(b)
    inner = range(len(b))
    cols = range(len(b[0]))
    return [[sum(a[k][i] * b[i][j] for i in inner) % m for j in cols] for k in range(len(a))]

def mod_add(a,b,m):
    """Return the entry-wise sum of equally shaped matrices a and b, modulo m."""
    assert len(a[0]) == len(b[0]) and len(a) == len(b)
    width = range(len(a[0]))
    return [[(row_a[j] + row_b[j]) % m for j in width] for row_a, row_b in zip(a, b)]

def mash(x):
    """Recompute the server's hash of the byte string x locally.

    Uses the module-level matrices A, B, C rebuilt from the leaked key. Each
    bit of x, most significant first, selects A (bit 0) or B (bit 1). The
    selected matrices are multiplied together mod 2. The length term
    C^(len(x)+1) is then added mod 2, and the sum is packed with mat_to_bytes.
    """
    bitstring = '{:0{n}b}'.format(int.from_bytes(x, 'big'), n=8 * len(x))
    ret = A if bitstring[0] == '0' else B
    for bit in bitstring[1:]:
        ret = mod_mult(ret, A if bit == '0' else B, 2)
    # Start from C and multiply len(x) more times: C^(len(x)+1).
    length_term = C
    for _ in range(len(x)):
        length_term = mod_mult(length_term, C, 2)
    return mat_to_bytes(mod_add(ret, length_term, 2))

ALP = range(ord(' '), ord('~'))

################################################################

# Any command-line argument selects a local copy of the challenge instead of
# the remote service.
if len(sys.argv) > 1:
    r = process(['python3', 'hatmash.py'], level='warn')
else:
    r = remote('34.139.216.197', 10001, level='warn')

# The server prints its random key, so the hash matrices can be rebuilt here
# exactly as the server builds them.
r.recvuntil(b'KEY: ')
KEY = bytes.fromhex(r.recvline(0).decode())

r.recvuntil(b'TARGET: ')
TARGET = bytes.fromhex(r.recvline(0).decode())
A, B, C = [bytes_to_mat(KEY[i::3]) for i in range(3)]

################################################################

M_a = matrix(GF(2), A)
M_b = matrix(GF(2), B)
M_c = matrix(GF(2), C)

# The attack relies on M^(ord(M)+1) == M, which needs invertible matrices.
# Otherwise give up on this key and let the caller rerun the script.
# sys.exit() is used because the site-provided exit() is only meant for the
# interactive interpreter.
if not (M_a.is_invertible() and M_b.is_invertible() and M_c.is_invertible()):
    sys.exit()

ord_a = M_a.multiplicative_order()
ord_b = M_b.multiplicative_order()
ord_c = M_c.multiplicative_order()
# The length padding grows in steps of ord(C); a large order makes the
# payload too long, so retry with a new key instead.
if ord_c > 2048:
    sys.exit()
print('\n', ord_a, ord_b, ord_c)

################################################################

def find_smallest_order():
    """Pick the target-message character whose matrix has the lowest order.

    Candidates are the characters ' 3abfgilmopx'. A character's matrix is the
    product of its eight per-bit matrices, MSB first: M_a for a 0 bit and M_b
    for a 1 bit. Returns [ch, order]; ties go to the character listed first.
    """
    def char_matrix(c):
        # Multiply the per-bit matrices of the byte ord(c), MSB first.
        byte_bits = bin(ord(c))[2:].zfill(8)
        acc = M_a if byte_bits[0] == '0' else M_b
        for bit in byte_bits[1:]:
            acc = acc * (M_a if bit == '0' else M_b)
        return acc

    scored = ((c, char_matrix(c).multiplicative_order()) for c in ' 3abfgilmopx')
    # min() keeps the first minimal entry, matching a strict "<" scan.
    ch, order = min(scored, key=lambda pair: pair[1])
    return [ch, order]

ch, minord = find_smallest_order()
print('ch minord', ch, minord)

# Keep the expanded payload short enough to send and to hash locally.
if minord > 2048:
    sys.exit()

payload1 = 'gib m3 flag plox?'
# Every occurrence of ch becomes ch*(minord+1), whose matrix is
# M_ch^(minord+1) == M_ch, so the A/B product of payload2 equals payload1's.
payload2 = payload1.replace(ch, ch*(minord+1))
print('len payload2', len(payload2))

################################################################

# Extra characters still needed so that len(final) == len(payload1) (mod ord_c),
# which makes the C^(len+1) length term match too.
delta = (len(payload1) - len(payload2)) % ord_c
print('delta', delta)
# Compare residues on both sides: the old check against the raw length 17
# failed whenever ord_c <= 17.
assert (len(payload2) + delta) % ord_c == len(payload1) % ord_c

################################################################

def find_ch_order():
    """List candidate padding strings with the orders of their matrices.

    Candidates are every single character with code in
    range(ord(' '), ord('~')), followed by every ordered pair of those
    characters (first character outer, second inner). A string's matrix is
    the product of its per-bit matrices, MSB first: M_a for a 0 bit and M_b
    for a 1 bit. Returns a list of [string, multiplicative_order] entries in
    that order.
    """
    def char_matrix(code):
        # Multiply the eight per-bit matrices of one byte, MSB first.
        byte_bits = bin(code)[2:].zfill(8)
        acc = M_a if byte_bits[0] == '0' else M_b
        for bit in byte_bits[1:]:
            acc = acc * (M_a if bit == '0' else M_b)
        return acc

    codes = range(ord(' '), ord('~'))
    mats = {code: char_matrix(code) for code in codes}

    box = [[chr(code), mats[code].multiplicative_order()] for code in codes]
    # Matrix multiplication is associative, so a two-character string's
    # 16-bit product equals the product of its two character matrices;
    # reuse them instead of walking all 16 bits again for every pair.
    box += [
        [chr(i) + chr(j), (mats[i] * mats[j]).multiplicative_order()]
        for i in codes
        for j in codes
    ]
    return box

ch_ord = find_ch_order()

# Padding search: find the smallest z == delta (mod ord_c) that can be filled
# by repeating a candidate string s exactly z // len(s) times with
# M_s^(z // len(s)) == I. That needs ord(M_s) | z // len(s), i.e.
# len(s) * ord(M_s) | z. Checking z % ord(M_s) and z % len(s) separately is
# not enough for two-character strings whose order is even.
found = False
for k in range(ord_c):
    z = delta + k * ord_c
    for ch, oneord in ch_ord:
        if z % (len(ch) * oneord) == 0:
            print('k', k)
            print('z', z)
            print('ch oneord', ch, oneord)
            found = True
            break
    if found:
        break

# Without valid padding the payload cannot collide; do not send leftovers
# from the last loop iteration, just retry with a new key.
if not found:
    sys.exit()

payload2 += ch * (z // len(ch))
print('len payload2', len(payload2))

if len(payload2) > 2048:
    sys.exit()

################################################################

print(TARGET.hex())
print(mash(payload2.encode()).hex())
r.sendline(payload2.encode())
r.interactive()

Run this script repeatedly until $ A, B, C $ are all invertible and $ C $ has a small multiplicative order; the script exits early for any other key.

$ while true; do python3 solve.py; done
...
32385 32766 204
ch minord a 126
len payload2 143
delta 78
k 0
z 78
ch oneord "F 39
len payload2 221
b1561dcd7248584b7d38dece7aa4a6cc71a3feff5ccbf0030e46ea4c36e243bd
b1561dcd7248584b7d38dece7aa4a6cc71a3feff5ccbf0030e46ea4c36e243bd
Wow, well I suppose you deserve it UMASS{m4tr1c3s_4r3_dumb_ch4ng3_my_m1nd}

Flag

UMASS{m4tr1c3s_4r3_dumb_ch4ng3_my_m1nd}

Epilogue

This was my first time being the only person to solve a challenge in an international CTF. Our team didn’t perform at our best, but we still finished in the top 10 out of 314 teams. Kudos to Polymero, the author of the crypto challenges, for creating such incredible challenges for this CTF.