UMassCTF 2022 - Crypto

UMassCTF is a 48-hour Jeopardy-style Capture The Flag event hosted by the University of Massachusetts Amherst Cybersecurity Club. I played with the AmpunBangJago team, and we were the only team to solve hatmash, one of the crypto challenges.

hatmash (500 points, 1 solve)

Challenge

What do you mean “We think you spend too much time with matrices.”? It’s just a hash function, jeez…

nc 34.139.216.197 10001

#!/usr/bin/env python3
import os

with open('flag.txt','rb') as f:
    FLAG = f.read()
    f.close()

def bytes_to_mat(x):
    assert len(x) == 32
    bits = list('{:0256b}'.format(int.from_bytes(x,'big')))
    return [[int(j) for j in bits[i:i+16]] for i in range(0,256,16)]

def mat_to_bytes(x):
    return int(''.join([str(i) for j in x for i in j]),2).to_bytes((len(x)*len(x[0])+7)//8,'big')

def mod_mult(a,b,m):
    assert len(a[0]) == len(b)
    return [[sum([a[k][i] * b[i][j] for i in range(len(b))]) % m for j in range(len(a))] for k in range(len(a))]

def mod_add(a,b,m):
    assert len(a[0]) == len(b[0]) and len(a) == len(b)
    return [[(a[i][j] + b[i][j]) % m for j in range(len(a[0]))] for i in range(len(a))]

KEY = os.urandom(32*3)
print('KEY:', KEY.hex())

A,B,C = [bytes_to_mat(KEY[i::3]) for i in range(3)]

def mash(x):
    bits = list('{:0{n}b}'.format(int.from_bytes(x,'big'), n = 8*len(x)))
    if bits.pop(0) == '0':
        ret = A
    else:
        ret = B
    for bit in bits:
        if bit == '0':
            ret = mod_mult(ret, A, 2)
        else:
            ret = mod_mult(ret, B, 2)
    lenC = C
    for _ in range(len(x)):
        lenC = mod_mult(lenC, C, 2)
    return mat_to_bytes(mod_add(ret, lenC, 2))

target_hash = mash(b"gib m3 flag plox?").hex()
print('TARGET:', target_hash)

ALP = range(ord(' '), ord('~'))

try:
    user_msg = input().encode()
    assert all(i in ALP for i in list(user_msg))
    
    if b"gib m3 flag plox?" in user_msg:
        print('Uuh yeah nice try...')

    elif mash(user_msg).hex() == target_hash:
        print('Wow, well I suppose you deserve it {}'.format(FLAG.decode()))

    else:
        print('Not quite, try again...')
    
except:
    print('Try to be serious okay...')

Solution

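In this challenge, we must provide an input that collides with the string gib m3 flag plox? under a matrix-based hash function, without containing that string. The matrices $ A, B, C $ are derived from a fresh random key on every connection, but the server prints the key, so we know them. The hash multiplies one matrix per input bit ($ A $ for a 0 bit, $ B $ for a 1 bit) and adds $ C^{\mathrm{len}(x)+1} $, all over $ GF(2) $. A collision is possible when all of the $ A, B, C $ matrices are invertible: an invertible matrix $ A $ over $ GF(2) $ has a finite multiplicative order $ \mathrm{ord}(A) $, the smallest positive integer such that $ A^{\mathrm{ord}(A)} = I $, and therefore $ A^{\mathrm{ord}(A)+1} = A $.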
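To make that description concrete, the sketch below rewrites `mash` in closed form, $ H(x) = M_{b_1} M_{b_2} \cdots M_{b_n} + C^{\mathrm{len}(x)+1} $ with $ M_0 = A $, $ M_1 = B $ and $ n = 8 \cdot \mathrm{len}(x) $. It is a minimal pure-Python sketch, not the author's code: the helper names (`bytes_to_rows`, `mat_mul`, `mat_pow`, `hash_closed_form`) are hypothetical, and each 16x16 matrix is stored as a list of 16-bit row integers using the same bit layout as the challenge's `bytes_to_mat`.

```python
import os

N = 16  # the challenge uses 16x16 matrices over GF(2)


def bytes_to_rows(x: bytes) -> list[int]:
    """Same layout as the challenge's bytes_to_mat: row i is bytes 2i..2i+1, MSB = column 0."""
    assert len(x) == 32
    return [int.from_bytes(x[2 * i:2 * i + 2], "big") for i in range(N)]


def identity() -> list[int]:
    return [1 << (N - 1 - i) for i in range(N)]


def mat_mul(a: list[int], b: list[int]) -> list[int]:
    """GF(2) product: row i of a*b is the XOR of the rows of b selected by row i of a."""
    out = []
    for row in a:
        acc = 0
        for j in range(N):
            if (row >> (N - 1 - j)) & 1:
                acc ^= b[j]
        out.append(acc)
    return out


def mat_pow(m: list[int], e: int) -> list[int]:
    """Square-and-multiply exponentiation over GF(2)."""
    result = identity()
    while e:
        if e & 1:
            result = mat_mul(result, m)
        m = mat_mul(m, m)
        e >>= 1
    return result


def hash_closed_form(x: bytes, A, B, C) -> bytes:
    """H(x) = M_b1 * M_b2 * ... * M_bn + C^(len(x)+1), with M_0 = A and M_1 = B."""
    product = identity()
    for byte in x:
        for bit in f"{byte:08b}":
            product = mat_mul(product, B if bit == "1" else A)
    len_term = mat_pow(C, len(x) + 1)
    rows = [p ^ c for p, c in zip(product, len_term)]  # addition over GF(2) is XOR
    return b"".join(r.to_bytes(2, "big") for r in rows)


if __name__ == "__main__":
    key = os.urandom(32 * 3)
    A, B, C = [bytes_to_rows(key[i::3]) for i in range(3)]
    print(hash_closed_form(b"gib m3 flag plox?", A, B, C).hex())
```

Written this way, only the product term depends on the message content and only the $ C $ term depends on its length, which is exactly what the attack exploits.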

sage: A = random_matrix(GF(2), 10, 10); A
[1 0 1 0 1 0 1 1 0 0]
[0 1 0 0 0 0 0 1 0 0]
[0 0 1 0 0 1 1 0 0 1]
[1 1 0 1 0 1 0 1 1 1]
[0 0 1 0 0 1 0 1 1 1]
[1 1 1 0 0 0 0 1 1 1]
[1 0 0 1 0 0 0 0 0 1]
[0 0 0 0 1 1 0 0 0 1]
[1 1 0 1 1 0 1 1 0 0]
[1 1 1 0 0 1 1 1 0 0]
sage: A.is_invertible()
True
sage: ord_A = A.multiplicative_order(); ord_A
372
sage: A^(ord_A+1) == A
True
sage: A*(ord_A+1) == A
True
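The same periodicity can be checked without Sage. The following brute-force sketch (function names are hypothetical) multiplies a matrix by itself until it reaches the identity. It is only practical for small orders, which is why the solver script relies on Sage's `multiplicative_order()` instead. The example matrix is the companion matrix of the primitive polynomial $ x^4 + x + 1 $, whose order is $ 2^4 - 1 = 15 $.

```python
def identity(n: int) -> list[int]:
    """n x n identity matrix; each row is an int whose MSB is column 0."""
    return [1 << (n - 1 - i) for i in range(n)]


def mat_mul(a: list[int], b: list[int], n: int) -> list[int]:
    """GF(2) product of two n x n matrices stored as row bitmasks."""
    out = []
    for row in a:
        acc = 0
        for j in range(n):
            if (row >> (n - 1 - j)) & 1:
                acc ^= b[j]
        out.append(acc)
    return out


def brute_force_order(m: list[int], n: int, limit: int = 1 << 20):
    """Smallest k >= 1 with m^k = I, or None if none is found within `limit` steps
    (a singular matrix never reaches I)."""
    ident = identity(n)
    power = m
    for k in range(1, limit + 1):
        if power == ident:
            return k
        power = mat_mul(power, m, n)
    return None


if __name__ == "__main__":
    # Companion matrix of the primitive polynomial x^4 + x + 1 over GF(2).
    comp = [0b0001, 0b1001, 0b0100, 0b0010]
    k = brute_force_order(comp, 4)
    print(k)  # 15
    power = comp
    for _ in range(k):
        power = mat_mul(power, comp, 4)
    print(power == comp)  # comp^(k+1) == comp, prints True
```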

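Let $ \mathrm{ord}(x) $ denote the multiplicative order of the matrix formed by character x, i.e. the product of the $ A $/$ B $ matrices selected by its eight bits. The idea is to create a payload such as gib m3 flag ploxxx...xxx?, in which character x is repeated $ \mathrm{ord}(x)+1 $ times: the product of those copies equals the matrix of a single x, so the matrix product is unchanged and the payload no longer contains the forbidden string. However, $ C $ guards against such collisions by making the hash depend on the input length. So, in the implementation, we also need to pad the input until $ \mathrm{len(payload)} \equiv 17 \pmod{\mathrm{ord}(C)} $, where 17 is the length of gib m3 flag plox?. We inflate the character of the target string that has the smallest multiplicative order, then append copies of a padding string $ z $ (one or two printable characters), choosing the smallest $ k $ such that $ \Delta l + k \times \mathrm{ord}(C) $ is a multiple of $ \mathrm{len}(z) \times \mathrm{ord}(z) $, where $ \Delta l = (17 - \mathrm{len(payload)}) \bmod \mathrm{ord}(C) $. Since the matrix of $ z $ raised to $ \mathrm{ord}(z) $ is the identity, the padding fixes the length without changing the product.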
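The length bookkeeping is plain modular arithmetic once the orders are known. This is a minimal sketch under stated assumptions: `inflate` and `plan_padding` are hypothetical helpers, the padding candidates and their orders are taken as given (in the solver they come from Sage), and the 2048-character budget mirrors the solver's length cutoff. The usage lines plug in the numbers from the sample run at the end of this write-up (`a` with order 126, $ \mathrm{ord}(C) = 204 $, padding string `"F` with order 39).

```python
from typing import Optional

TARGET = "gib m3 flag plox?"  # 17 characters


def inflate(target: str, ch: str, order: int) -> str:
    """Repeat every occurrence of ch order+1 times; since X^(order+1) = X,
    the matrix product of the payload equals that of the target."""
    return target.replace(ch, ch * (order + 1))


def plan_padding(target_len: int, payload_len: int, ord_c: int,
                 pad_orders: list[tuple[str, int]],
                 max_len: int = 2048) -> Optional[tuple[int, int, str, int]]:
    """Find the smallest k such that z = delta + k*ord_c padding characters
    can be written as reps copies of some string s with S^reps = I.
    Returns (k, z, s, reps), or None once the length budget is exceeded."""
    delta = (target_len - payload_len) % ord_c
    k = 0
    while payload_len + delta + k * ord_c <= max_len:
        z = delta + k * ord_c
        for s, order in pad_orders:
            reps = z // len(s)
            # s must tile z exactly AND the number of copies must be a multiple of ord(s)
            if z % len(s) == 0 and reps % order == 0:
                return k, z, s, reps
        k += 1
    return None


if __name__ == "__main__":
    payload = inflate(TARGET, "a", 126)
    k, z, s, reps = plan_padding(len(TARGET), len(payload), 204, [('"F', 39)])
    payload += s * reps
    print(len(payload), len(payload) % 204)  # 221 17
```

For a two-character padding string, checking only that $ z $ is a multiple of $ \mathrm{ord}(z) $ is not enough: the number of copies, $ z / 2 $, is what must be a multiple of the order.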

Implementation

#!/usr/bin/env python3
from pwn import *
from sage.all import *
import sys

def bytes_to_mat(x):
    assert len(x) == 32
    bits = list('{:0256b}'.format(int.from_bytes(x,'big')))
    return [[int(j) for j in bits[i:i+16]] for i in range(0,256,16)]

def mat_to_bytes(x):
    return int(''.join([str(i) for j in x for i in j]),2).to_bytes((len(x)*len(x[0])+7)//8,'big')

def mod_mult(a,b,m):
    assert len(a[0]) == len(b)
    return [[sum([a[k][i] * b[i][j] for i in range(len(b))]) % m for j in range(len(a))] for k in range(len(a))]

def mod_add(a,b,m):
    assert len(a[0]) == len(b[0]) and len(a) == len(b)
    return [[(a[i][j] + b[i][j]) % m for j in range(len(a[0]))] for i in range(len(a))]

def mash(x):
    bits = list('{:0{n}b}'.format(int.from_bytes(x,'big'), n = 8*len(x)))
    if bits.pop(0) == '0':
        ret = A
    else:
        ret = B
    for bit in bits:
        if bit == '0':
            ret = mod_mult(ret, A, 2)
        else:
            ret = mod_mult(ret, B, 2)
    lenC = C
    for _ in range(len(x)):
        lenC = mod_mult(lenC, C, 2)
    return mat_to_bytes(mod_add(ret, lenC, 2))

ALP = range(ord(' '), ord('~'))

################################################################

if len(sys.argv) > 1:
    r = process(['python3', 'hatmash.py'], level='warn')
else:
    r = remote('34.139.216.197', 10001, level='warn')

r.recvuntil(b'KEY: ')
KEY = bytes.fromhex(r.recvline(0).decode())

r.recvuntil(b'TARGET: ')
TARGET = bytes.fromhex(r.recvline(0).decode())
A, B, C = [bytes_to_mat(KEY[i::3]) for i in range(3)]

################################################################

M_a = matrix(GF(2), A)
M_b = matrix(GF(2), B)
M_c = matrix(GF(2), C)

# Only continue when A, B, C are all invertible and ord(C) is small; otherwise rerun
if M_a.is_invertible() and M_b.is_invertible() and M_c.is_invertible():
    ord_a = M_a.multiplicative_order()
    ord_b = M_b.multiplicative_order()
    ord_c = M_c.multiplicative_order()
    if ord_c > 2048: exit()
    print('\n', ord_a, ord_b, ord_c)
else:
    exit()

################################################################

def find_smallest_order():
    ch = ''
    minorder = 10**7

    for i in ' 3abfgilmopx':
        tmp = bin(ord(i))[2:].zfill(8)
        ret = M_a if tmp[0] == '0' else M_b
        for bit in tmp[1:]:
            if bit == '0':
                ret *= M_a
            else:
                ret *= M_b
        order = ret.multiplicative_order()
        if order < minorder:
            ch = i
            minorder = order

    return [ch, minorder]

ch, minord = find_smallest_order()
print('ch minord', ch, minord)

if minord > 2048: exit()

payload1 = 'gib m3 flag plox?'
# X^(ord(x)+1) = X, so repeating ch ord(ch)+1 times leaves the matrix product unchanged
payload2 = payload1.replace(ch, ch*(minord+1))
print('len payload2', len(payload2))

################################################################

# Characters still to append (mod ord(C)) so that len(payload2) = 17 mod ord(C)
delta = (len(payload1) - len(payload2)) % ord_c
print('delta', delta)
assert (len(payload2) + delta) % ord_c == len(payload1) % ord_c

################################################################

def find_ch_order():
    box = []

    for i in range(ord(' '), ord('~')):
        tmp = bin(i)[2:].zfill(8)
        ret = M_a if tmp[0] == '0' else M_b
        for bit in tmp[1:]:
            if bit == '0':
                ret *= M_a
            else:
                ret *= M_b
        order = ret.multiplicative_order()
        box.append([chr(i), order])

    for i in range(ord(' '), ord('~')):
        for j in range(ord(' '), ord('~')):
            tmp = bin(i)[2:].zfill(8) + bin(j)[2:].zfill(8)
            ret = M_a if tmp[0] == '0' else M_b
            for bit in tmp[1:]:
                if bit == '0':
                    ret *= M_a
                else:
                    ret *= M_b
            order = ret.multiplicative_order()
            box.append([chr(i)+chr(j), order])

    return box

ch_ord = find_ch_order()

found = False
for k in range(ord_c):
    z = delta + k * ord_c
    for s, oneord in ch_ord:
        # s must tile z exactly, and the number of copies must be a multiple of ord(s)
        if z % len(s) == 0 and (z // len(s)) % oneord == 0:
            print('k', k)
            print('z', z)
            print('ch oneord', s, oneord)
            found = True
            break
    if found: break

if not found: exit()

payload2 += s * (z // len(s))
print('len payload2', len(payload2))

if len(payload2) > 2048: exit()

################################################################

print(TARGET.hex())
print(mash(payload2.encode()).hex())
r.sendline(payload2.encode())
r.interactive()

Run this script repeatedly until all of the $ A, B, C $ matrices are invertible and $ C $ has a small multiplicative order.
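Many runs fail at the first check. The three matrices come from disjoint bytes of a uniformly random key, and a uniformly random $ n \times n $ matrix over $ GF(2) $ is invertible with probability $ \prod_{i=1}^{n} (1 - 2^{-i}) \approx 0.289 $, so all three are invertible only about 2.4% of the time, before even considering $ \mathrm{ord}(C) $. The sketch below checks invertibility with Gaussian elimination over $ GF(2) $ as a pure-Python stand-in for Sage's `is_invertible()` (helper names are hypothetical) and estimates that rate empirically.

```python
import os
from math import prod

N = 16  # matrix size used by the challenge


def bytes_to_rows(x: bytes) -> list[int]:
    """16 row bitmasks with the same layout as the challenge's bytes_to_mat."""
    assert len(x) == 32
    return [int.from_bytes(x[2 * i:2 * i + 2], "big") for i in range(N)]


def gf2_rank(rows: list[int]) -> int:
    """Rank over GF(2) via Gauss-Jordan elimination on row bitmasks."""
    rows = list(rows)
    rank = 0
    for bit in reversed(range(N)):  # columns 0..N-1 live in bits N-1..0
        pivot = next((i for i in range(rank, len(rows)) if (rows[i] >> bit) & 1), None)
        if pivot is None:
            continue
        rows[rank], rows[pivot] = rows[pivot], rows[rank]
        for i in range(len(rows)):
            if i != rank and (rows[i] >> bit) & 1:
                rows[i] ^= rows[rank]
        rank += 1
    return rank


def is_invertible(rows: list[int]) -> bool:
    return gf2_rank(rows) == N


# Probability that a uniformly random N x N matrix over GF(2) is invertible.
P_INVERTIBLE = prod(1 - 2.0 ** -i for i in range(1, N + 1))

if __name__ == "__main__":
    trials = 2000
    hits = sum(
        all(is_invertible(bytes_to_rows(os.urandom(32))) for _ in range(3))
        for _ in range(trials)
    )
    print(f"empirical {hits / trials:.4f}, predicted {P_INVERTIBLE ** 3:.4f}")
```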

$ while true; do python3 solve.py; done
...
32385 32766 204
ch minord a 126
len payload2 143
delta 78
k 0
z 78
ch oneord "F 39
len payload2 221
b1561dcd7248584b7d38dece7aa4a6cc71a3feff5ccbf0030e46ea4c36e243bd
b1561dcd7248584b7d38dece7aa4a6cc71a3feff5ccbf0030e46ea4c36e243bd
Wow, well I suppose you deserve it UMASS{m4tr1c3s_4r3_dumb_ch4ng3_my_m1nd}

Flag

UMASS{m4tr1c3s_4r3_dumb_ch4ng3_my_m1nd}

Epilogue

This was my first time being the only person to solve a challenge by the end of an international CTF. Our team did not perform at its best, but we still finished in the top 10 out of 314 teams. Kudos to Polymero, the author of the crypto challenges, for serving high-quality crypto challenges in this CTF.