[HITCON 2023] Careless Padding write-up

This is a writeup for the Careless Padding challenge of the HITCON 2023 CTF. We analyse a vulnerable usage of AES in CBC mode, and recover the encrypted secret.

Presentation of the challenge

This challenge was the second most solved of all the crypto challenges of the HITCON 2023 CTF. Thanks to bronson113 for the challenge.

Before downloading the challenge, the following text is displayed:

How careless can you be as an assistant…

Understanding the challenge

The following code is running on the server.

#!/usr/local/bin/python
import random
import os
from secret import flag
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import json

N = 16

# 0 -> 0, 1~N -> 1, (N+1)~(2N) -> 2 ...
def count_blocks(length):
    block_count = (length-1) // N + 1
    return block_count

def find_repeat_tail(message):
    Y = message[-1]
    message_len = len(message)
    for i in range(len(message)-1, -1, -1):
        if message[i] != Y:
            X = message[i]
            message_len = i + 1
            break
    return message_len, X, Y

def my_padding(message):
    message_len = len(message)
    block_count = count_blocks(message_len)
    result_len =  block_count * N
    if message_len % N == 0:
        result_len += N
    X = message[-1]
    Y = message[(block_count-2)*N+(X%N)]
    if X==Y:
        Y = Y^1
    padded = message.ljust(result_len, bytes([Y]))
    return padded

def my_unpad(message):
    message_len, X, Y = find_repeat_tail(message)
    block_count = count_blocks(message_len)
    _Y = message[(block_count-2)*N+(X%N)]
    if (Y != _Y and Y != _Y^1):
        raise ValueError("Incorrect Padding")
    return message[:message_len]

def chal():
    k = os.urandom(16)
    m = json.dumps({'key':flag}).encode()

    iv = os.urandom(16)
    cipher = AES.new(k, AES.MODE_CBC, iv)

    padded = my_padding(m)
    enc = cipher.encrypt(padded)
    print(f"""
*********************************************************
You are put into the careless prison and trying to escape.
Thanksfully, someone forged a key for you, but seems like it's encrypted... 
Fortunately they also leave you a copied (and apparently alive) prison door.
The replica pairs with this encrypted key. Wait, how are this suppose to help?
Anyway, here's your encrypted key: {(iv+enc).hex()}
*********************************************************
""")

    while True:
        enc = input("Try unlock:")
        enc = bytes.fromhex(enc)
        iv = enc[:16]
        cipher = AES.new(k, AES.MODE_CBC, iv)
        try:
            message = my_unpad(cipher.decrypt(enc[16:]))
            if message == m:
                print("Hey you unlock me! At least you know how to use the key")
            else:
                print("Bad key... do you even try?")
        except ValueError:
            print("Don't put that weirdo in me!")
        except Exception:
            print("What? Are you trying to unlock me with a lock pick?")

if __name__ == "__main__":
    chal()

def chal

Looking at the chal function, we can see that the program generates a random 16-byte key k and a random 16-byte initialization vector iv, and initializes an AES cipher in CBC mode with these parameters. It then pads the bytes of a JSON document containing the flag and encrypts the padded result with that cipher.
We are then given the IV and the ciphertext, and we need to find the flag. To do so, we can send messages to the server. Each message consists of a 16-byte IV followed by the ciphertext of a padded message.
The server decrypts our input and unpads it, before comparing it to the initial message. The server can then return 4 different outputs:

  • If the decryption and unpadding went smoothly:
    • If the decrypted message is the same as the one containing the flag, the server prints "Hey you unlock me! At least you know how to use the key"
    • If the decrypted message is different, the server prints "Bad key... do you even try?".
  • If we encounter a ValueError, it returns "Don't put that weirdo in me!".
  • Another Exception leads to a printed value of "What? Are you trying to unlock me with a lock pick?".
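The four outputs above can be told apart with a small parser. The following is only a sketch: the names Outcome, classify and padding_accepted are hypothetical helpers, not part of the challenge. Substring matching is used because the server prints its "Try unlock:" prompt on the same line as the answer.

```python
from enum import Enum


class Outcome(Enum):
    UNLOCKED = "unlocked"        # unpadding succeeded and the message matches
    BAD_KEY = "bad key"          # unpadding succeeded but the message differs
    BAD_PADDING = "bad padding"  # my_unpad raised ValueError
    OTHER_ERROR = "other error"  # any other exception on the server side


def classify(line: bytes) -> Outcome:
    """Map one line of server output to an Outcome (hypothetical helper)."""
    if b"Hey you unlock me" in line:
        return Outcome.UNLOCKED
    if b"Bad key" in line:
        return Outcome.BAD_KEY
    if b"weirdo" in line:
        return Outcome.BAD_PADDING
    if b"lock pick" in line:
        return Outcome.OTHER_ERROR
    raise ValueError(f"unexpected server output: {line!r}")


def padding_accepted(line: bytes) -> bool:
    """True when the server did not report a padding ValueError."""
    return classify(line) is not Outcome.BAD_PADDING


print(classify(b"Try unlock:Don't put that weirdo in me!\n"))
```

For the attack, only the distinction between BAD_PADDING and everything else matters, which is why padding_accepted collapses the other three cases.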

Given the name of the challenge, the vulnerability is most likely in the padding function my_padding or in the unpadding function my_unpad. my_unpad raises a ValueError in some cases, and we will be able to use this information to gain knowledge about the plaintext.

def my_padding

Let’s get started and dive into these functions. Let’s start with the padding function my_padding.

N = 16

# 0 -> 0, 1~N -> 1, (N+1)~(2N) -> 2 ...
def count_blocks(length):
    block_count = (length-1) // N + 1
    return block_count

def my_padding(message):
    message_len = len(message)
    block_count = count_blocks(message_len)
    result_len =  block_count * N
    if message_len % N == 0:
        result_len += N
    X = message[-1]
    Y = message[(block_count-2)*N+(X%N)]
    if X==Y:
        Y = Y^1
    padded = message.ljust(result_len, bytes([Y]))
    return padded

We start by counting how many blocks of N (16) bytes are required to hold the whole message. The length of the padded message is then size of a block * number of blocks. When the message exactly fills its blocks, we add one extra block of padding. Otherwise we would not be able to tell whether the message exactly fills block_count blocks or has been padded to fit.
We then take the last byte of the message, X. We use X mod 16 as an index to pick a byte Y in the penultimate block of the message. We use this byte Y to right-pad the message up to the length computed before. This means the last block will look like: <beginning of the block>XYYYYYYYY. If X and Y are equal, we flip the last bit of Y to keep them distinct.
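To make the padding rule concrete, here is a small run of my_padding on two inputs. The function is copied from the challenge so that the example runs on its own; the flag value hitcon{test} and the 17-byte message of A characters are made-up inputs for illustration.

```python
N = 16


def count_blocks(length):
    return (length - 1) // N + 1


def my_padding(message):
    # Same logic as the challenge code, reproduced so the example runs alone.
    message_len = len(message)
    block_count = count_blocks(message_len)
    result_len = block_count * N
    if message_len % N == 0:
        result_len += N
    X = message[-1]
    Y = message[(block_count - 2) * N + (X % N)]
    if X == Y:
        Y ^= 1
    return message.ljust(result_len, bytes([Y]))


# 23-byte message: X = '}' (125), 125 % 16 = 13, and byte 13 of the
# first block is 'o', so the message is padded with 'o'.
message = b'{"key": "hitcon{test}"}'
padded = my_padding(message)
print(padded)

# X = 'A' (65), 65 % 16 = 1, and byte 1 is also 'A': the low bit of Y is
# flipped, so the padding byte becomes '@' (64).
same = my_padding(b"A" * 17)
print(same)
```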

def my_unpad

Now that we understood how the padding works, let’s see how unpadding is done.

def find_repeat_tail(message):
    Y = message[-1]
    message_len = len(message)
    for i in range(len(message)-1, -1, -1):
        if message[i] != Y:
            X = message[i]
            message_len = i + 1
            break
    return message_len, X, Y

def my_unpad(message):
    message_len, X, Y = find_repeat_tail(message)
    block_count = count_blocks(message_len)
    _Y = message[(block_count-2)*N+(X%N)]
    if (Y != _Y and Y != _Y^1):
        raise ValueError("Incorrect Padding")
    return message[:message_len]

This is basically my_padding, but backwards, as one could expect. In find_repeat_tail, we read the message backward: we record the last byte Y, then walk back byte by byte for as long as the byte we read is equal to Y. When it changes, we note the index we reached and use it to compute the length of the unpadded message. The value we just read is X, the last byte of the message.

We then verify that Y is indeed the byte indexed by X, or that byte with its last bit flipped. Note that the check accepts both values, whether or not X and the indexed byte are equal. If Y matches neither of them, we raise a ValueError. Nice, we will certainly be able to use this to gain knowledge about those bytes.
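The behaviour of the check can be observed locally. This sketch copies the unpadding functions from the challenge and feeds them a padded message built from a made-up flag; accepted and with_byte are hypothetical helpers introduced only for the demonstration.

```python
N = 16


def count_blocks(length):
    return (length - 1) // N + 1


def find_repeat_tail(message):
    Y = message[-1]
    message_len = len(message)
    for i in range(len(message) - 1, -1, -1):
        if message[i] != Y:
            X = message[i]
            message_len = i + 1
            break
    return message_len, X, Y


def my_unpad(message):
    message_len, X, Y = find_repeat_tail(message)
    block_count = count_blocks(message_len)
    _Y = message[(block_count - 2) * N + (X % N)]
    if Y != _Y and Y != _Y ^ 1:
        raise ValueError("Incorrect Padding")
    return message[:message_len]


def accepted(data):
    """True when my_unpad does not raise ValueError."""
    try:
        my_unpad(data)
        return True
    except ValueError:
        return False


def with_byte(data, index, value):
    """Copy of data with one byte replaced."""
    return data[:index] + bytes([value]) + data[index + 1:]


# Padded form of a made-up message: X = '}', indexed byte 13 is 'o'.
padded = b'{"key": "hitcon{test}"}' + b"o" * 9

print(my_unpad(padded))                            # original message
print(accepted(with_byte(padded, 13, ord("n"))))   # 'n' == 'o' ^ 1: accepted
print(accepted(with_byte(padded, 13, ord("x"))))   # anything else: rejected
```

Exactly two values of the indexed byte pass the check, which is what later leaves two candidates for Y.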

Finding the vulnerability

So, the information that we have is whether the message has been correctly unpadded or not. We also actually know the first 16 bytes of the plaintext.

>>> print(json.dumps({'key':'hitcon{test}'})[:16])
{"key": "hitcon{

How lucky we are…

Reducing the field of possibilities

To simplify our example, let’s assume that the message only has two blocks. We will then be able to repeat the process for any number of blocks.
We send the server (our oracle) the following 48 bytes:

IV + first_block_ciphertext + second_block_ciphertext

After decrypting the ciphertext, the oracle will analyse the padding. To do so, it will take the value Y of the last byte of the second block, and iterate backwards until it finds a byte with a different value X. Then, it will check that the byte at index X mod 16 in the first block is equal to Y or Y^1. Right now, the padding will almost certainly be invalid, because we truncated the input, and the end of the second block is not a real padding.
But if we go through every byte position of the plaintext of the first block and try every possible value at that position, at some point the padding verification will pass. We will then know the index X mod 16, and Y or Y^1, i.e. the value of the last byte of the second block up to its last bit! Obviously, we cannot modify the plaintext directly, because we don’t know how this would impact the ciphertext. But we can do the same iteration over the IV, because it is directly XORed with the output of the block cipher decryption of the first block to give the plaintext. So flipping a bit in the IV also flips the same bit in the resulting plaintext (for a given ciphertext).
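The IV iteration can be sketched offline without AES. The simulation below rests on one assumption, stated in the code as well: for the first block, CBC decryption is "block cipher output XOR IV", so a fixed random value d1 stands in for the AES output. make_oracle and find_index_and_y are hypothetical names, and the second plaintext block is taken from the flag shown in the commented-out until_now line of the script.

```python
import os

N = 16


def count_blocks(length):
    return (length - 1) // N + 1


def find_repeat_tail(message):
    Y = message[-1]
    message_len = len(message)
    for i in range(len(message) - 1, -1, -1):
        if message[i] != Y:
            X = message[i]
            message_len = i + 1
            break
    return message_len, X, Y


def my_unpad(message):
    message_len, X, Y = find_repeat_tail(message)
    _Y = message[(count_blocks(message_len) - 2) * N + (X % N)]
    if Y != _Y and Y != _Y ^ 1:
        raise ValueError("Incorrect Padding")
    return message[:message_len]


def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


def make_oracle(p1, p2):
    """Simulated server: returns (base_iv, oracle) for plaintext blocks p1, p2."""
    base_iv = os.urandom(N)
    d1 = xor(p1, base_iv)  # ASSUMPTION: stands in for AES-decrypting block 1

    def oracle(iv):
        try:
            my_unpad(xor(d1, iv) + p2)
            return True
        except ValueError:
            return False

    return base_iv, oracle


def find_index_and_y(oracle, base_iv, known_p1):
    """Try every value at every IV position until the padding check passes."""
    for j in range(N):
        for v in range(256):
            iv = base_iv[:j] + bytes([v]) + base_iv[j + 1:]
            if oracle(iv):
                # The plaintext byte forced at position j is Y or Y ^ 1.
                return j, known_p1[j] ^ v ^ base_iv[j]
    raise ValueError("no IV passed the padding check")


p1 = b'{"key": "hitcon{'   # known first plaintext block
p2 = b"p4dd1ng_w0n7_s4v"   # second block, unknown to a real attacker
base_iv, oracle = make_oracle(p1, p2)
index, y = find_index_and_y(oracle, base_iv, p1)
print(index, y)
```

Here the second block ends with 4v, so the search finds index ord('4') % 16 = 4, and y is either ord('v') or ord('v') ^ 1. The sketch assumes, as in the text, that the unmodified IV is rejected.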

The iteration gives us two possibilities for the last byte and 16 for the penultimate byte. In fact, we know the format of a flag, and we know that each byte must represent an ASCII character, so there are fewer than 16 possible values for the penultimate byte.
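The remaining candidates can be listed explicitly. In this sketch, candidate_pairs is a hypothetical helper; the bounds 48 and 126 are the ones used as a filter in the solving script further down, and are an assumption about which characters appear in the flag.

```python
def candidate_pairs(x_mod_16, y, lo=48, hi=126):
    """List the (X, Y) pairs compatible with one oracle hit.

    x_mod_16: index found by the IV iteration (X mod 16)
    y:        value recovered for the last byte, correct up to its low bit
    lo, hi:   inclusive bounds on plausible flag bytes (assumption)
    """
    xs = [x_mod_16 + 16 * i for i in range(16)]
    ys = [y, y ^ 1]
    return [(x, yy) for yy in ys for x in xs
            if lo <= x <= hi and lo <= yy <= hi]


pairs = candidate_pairs(4, ord("v"))
print(len(pairs))
print([chr(x) + chr(yy) for x, yy in pairs])
```

With X mod 16 = 4 and y = ord('v'), the filter keeps five values for X (4, D, T, d, t) and two for Y, so 10 pairs instead of 32.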

Bruteforcing the right candidates

It would be great if we could decide what the real bytes actually are.

As a reminder, decryption using AES CBC works as follows:

[Figure: Cipher Block Chaining (CBC) mode decryption. Each ciphertext block goes through the block cipher decryption with the key, and the result is XORed with the previous ciphertext block (the Initialization Vector (IV) for the first block) to give the plaintext block.]

As you can see, modifying a byte in the penultimate block of the ciphertext directly changes the byte at the same index in the plaintext of the last block. This means that we can toy with the padding by modifying the penultimate block of the ciphertext.
We said earlier that we had 2 possibilities for Y, and 16 for X (let’s forget the ASCII simplification for a minute). This gives us 32 possibilities for the (X,Y) pair. How could we bruteforce them to find the right one?
Well, let’s take a pair $(X_0,Y_0)$ and test whether it is the right one. To do so, we modify the end of the ciphertext of the first block, so that the last two bytes of the plaintext of the last block are both equal to $X_0$.
The penultimate byte of the plaintext is already equal to $X_0$ (this is our base assumption), so there is no need to touch it. To set the last byte to $X_0$, we XOR the last byte of the ciphertext of the first block with $Y_0$, which sets the plaintext byte to 0, and then with $X_0$.
Then we iterate on the IV again to find a potential new X mod 16 (where X is now the antepenultimate byte of the last block). That’s great, but we have not checked anything yet. Maybe we guessed $(X_0,Y_0)$ wrong, the last two bytes of the plaintext are not both equal to $X_0$, and our X is completely wrong.
To make sure that this does not happen, we do a second test, and set both bytes to $Y_0$ this time. If we find a different X mod 16, the last two bytes were not equal in at least one of the two tests, so we have not correctly guessed $(X_0,Y_0)$ and we need to test another pair.
Note that we may never find a pair that works. Indeed, if the byte just before the one we are testing is equal to the last byte of the block, or to the byte we are testing, the padding will be shifted by one. Fortunately, these cases are rare enough, and we can bruteforce our way out of them.
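The XOR manipulation of the last ciphertext byte can be checked on a simulated second block. As an assumption for the sketch, a fixed value d2 stands in for the AES decryption of the second ciphertext block, so that the plaintext is d2 XOR c1; force_last_byte is a hypothetical helper that mirrors the expression used in the solving script.

```python
import os

N = 16


def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


p2 = b"p4dd1ng_w0n7_s4v"  # plaintext of the second block (ends with X='4', Y='v')
c1 = os.urandom(N)        # first ciphertext block
d2 = xor(p2, c1)          # ASSUMPTION: stands in for AES-decrypting block 2


def decrypt_second_block(c1_modified):
    """CBC: plaintext of block 2 = D(c2) XOR (possibly modified) c1."""
    return xor(d2, c1_modified)


def force_last_byte(c1, y0, x0):
    """Tweak c1 so that a last plaintext byte equal to y0 becomes x0."""
    return c1[:-1] + bytes([c1[-1] ^ y0 ^ x0])


# Correct guess (X0, Y0) = ('4', 'v'): the block now ends with "44".
good = decrypt_second_block(force_last_byte(c1, ord("v"), ord("4")))
# Wrong guess for Y0 ('w' instead of 'v'): the last byte is off by one bit.
bad = decrypt_second_block(force_last_byte(c1, ord("w"), ord("4")))
print(good[-2:], bad[-2:])
```

With the correct guess the tail becomes 44, so the oracle now treats the byte before it as X. With the wrong guess the tail is 45 and the padding analysis does not move.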

Repeating the process

Once we have found a pair that works, we can continue the process. We again have 16 possibilities for our new X, which is the antepenultimate byte of the last block. We can check which one it actually is by using the same technique. We modify the ciphertext of the first block so that all the already known bytes at the end of the second block are set to our guessed X. We can then iterate over the IV to find our potential next X.
Then we do the same thing, except that we set all of these bytes, including the one we are guessing, to a known value from the end of the second plaintext block (the last one for example). We iterate again over the IV to find a candidate X. If both candidates are the same, that’s good news. Otherwise, it means that our initial guess for X is wrong.
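The same XOR trick generalises to the whole known tail. The sketch below again replaces AES by a fixed value d2 (an assumption), and set_tail is a hypothetical helper corresponding to the list comprehension that builds block in the solving script.

```python
import os

N = 16


def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


p2 = b"p4dd1ng_w0n7_s4v"  # plaintext of the second block
c1 = os.urandom(N)        # first ciphertext block
d2 = xor(p2, c1)          # ASSUMPTION: stands in for AES-decrypting block 2


def decrypt_second_block(c1_modified):
    return xor(d2, c1_modified)


def set_tail(c1, known_tail, value):
    """Tweak c1 so every already recovered tail byte decrypts to value."""
    k = len(known_tail)
    tail = bytes(c ^ p ^ value for c, p in zip(c1[N - k:], known_tail))
    return c1[:N - k] + tail


# The last two bytes "4v" are known; we guess the byte before them.
right = decrypt_second_block(set_tail(c1, b"4v", ord("s")))
wrong = decrypt_second_block(set_tail(c1, b"4v", ord("c")))
print(right[-3:], wrong[-3:])
```

When the guess is right, the three last bytes are identical (sss) and the oracle moves on to the byte before them. When it is wrong (scc), the guessed position itself is still seen as X.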

We can then proceed block by block and extract the flag. The last block is only partly plaintext, which raises a few issues. We don’t know how many bytes of the block are plaintext and how many are padding. But we do know that the plaintext ends with }"} because of the JSON. So X is }, which has the ASCII code 125, and Y is the byte at index 125 % 16 = 13 in the previous block. In this instance that byte turned out to be the ASCII character '8', which is therefore the padding byte.
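The lookup of the padding byte of the last block is a two-line computation. The value of prev_block below is taken from the last 16 characters of the commented-out until_now line in the script, i.e. the plaintext block just before the final one for this particular flag.

```python
N = 16

# Plaintext block preceding the last one (from the script's until_now line).
prev_block = b"ab91f57d1969e8e8"

x = ord("}")                  # last byte of the JSON document
index = x % N                 # position looked up in the previous block
pad_byte = prev_block[index]  # byte used to fill the end of the last block

print(x, index, chr(pad_byte))
```

Since '}' and '8' are different, the low bit is not flipped and the padding is a run of '8' characters, which matches the values hard-coded for the last block in the script.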

Given this information and a bit of bruteforcing, we can recover the complete flag.

Code

Here is the complete code:

from pwn import remote
import json
from time import sleep

BLOCK_SIZE = 16
known_blocks = [b'{"key": "hitcon{'[:16]]
#until_now = b'{"key": "hitcon{p4dd1ng_w0n7_s4v3_y0u_Fr0m_4_0rac13_617aa68c06d7ab91f57d1969e8e8'
#known_blocks = [until_now[i*16:(i+1)*16] for i in range(5)]

def get_X_and_Y_IV(soc, base_iv, base_cipher_block_1, base_cipher_block_2):
    # For each byte position of the IV, send all 256 possible values in one
    # batch and look for an answer that is not the invalid padding message.
    # Returns (X mod 16, IV byte value that made the padding check pass).
    X = None
    Y_iv = None
    for possible_X in range(BLOCK_SIZE):
        ivs = []
        for iv_value in range(256):
            iv = base_iv[:possible_X] + bytes([iv_value]) + base_iv[(possible_X+1):]
            assert(len(iv) == BLOCK_SIZE)
            ivs.append(iv)

        payload = [(iv + base_cipher_block_1 + base_cipher_block_2).hex() for iv in ivs]
        assert(len(payload[0]) == 2 * 3 * BLOCK_SIZE)
        soc.send(('\n'.join(payload) + '\n').encode())
        # Read all 256 answers so the socket stays in sync with the server
        for i in range(256):
            line = soc.recvline()
            if b'weirdo' not in line:
                X = possible_X
                Y_iv = i
        if X is not None:
            return X, Y_iv
    raise ValueError("Could not find potential X")
    
soc = remote("localhost", 11111, level='CRITICAL')
soc.recvuntil(b"Anyway, here's your encrypted key: ")
data = bytes.fromhex(soc.recvline().strip().decode())
blocks_number = len(data) // BLOCK_SIZE + (len(data) % BLOCK_SIZE != 0)
soc.close()

def reset(block_to_guess_index):
    soc = remote("localhost", 11111, level='CRITICAL')
    soc.recvuntil(b"Anyway, here's your encrypted key: ")
    data = bytes.fromhex(soc.recvline().strip().decode())
    soc.recvline() # *********
    soc.recvline() # empty line
    base_iv = data[BLOCK_SIZE*(block_to_guess_index-1):BLOCK_SIZE*block_to_guess_index]
    base_cipher_blocks = [data[BLOCK_SIZE*i:BLOCK_SIZE*(i+1)] for i in range(block_to_guess_index, len(data) // BLOCK_SIZE)]
    first_cipher_block = base_cipher_blocks[0]
    second_cipher_block = base_cipher_blocks[1]
    return soc, base_iv, first_cipher_block, second_cipher_block

for block_to_guess_index in range(1, blocks_number):
#for block_to_guess_index in range(5, blocks_number):

    # Reset socket so we have time to get another byte before timeout
    soc, base_iv, first_cipher_block, second_cipher_block = reset(block_to_guess_index)
    current_guessing = [0]*len(second_cipher_block) # Guessing of current block
    # Get the potential last two values for the block
    Xmod16, Y_iv = get_X_and_Y_IV(
        soc, 
        base_iv,
        base_cipher_block_1=first_cipher_block, 
        base_cipher_block_2=second_cipher_block
    )
    Y = known_blocks[-1][Xmod16] ^ Y_iv ^ base_iv[Xmod16]
    possibilities = [(Xmod16 + 16*i, Y) for i in range(16)] + [(Xmod16 + 16*i, Y^1) for i in range(16)]
    for possibility in possibilities:
        if not (48 <= possibility[0] <= 126) or not (48 <= possibility[1] <= 126):
            continue
        print(f"Trying {chr(possibility[0])}{chr(possibility[1])}")
        potentialX_1, _ = get_X_and_Y_IV(
            soc,
            base_iv,
            base_cipher_block_1=first_cipher_block[:-1] + bytes([possibility[0] ^ first_cipher_block[-1] ^ possibility[1]]),
            base_cipher_block_2=second_cipher_block
        )
        potentialX_2, _ = get_X_and_Y_IV(
            soc,
            base_iv,
            base_cipher_block_1=first_cipher_block[:-2] + bytes([possibility[0] ^ first_cipher_block[-2] ^ possibility[1]]) + first_cipher_block[-1:],
            base_cipher_block_2=second_cipher_block
        )
        if potentialX_1 == potentialX_2:
            print(f"Couple {possibility} looks good with X={potentialX_1}")
            current_guessing[-1] = possibility[1]
            current_guessing[-2] = possibility[0]
            potential_X = potentialX_1
            break
    else:
        if block_to_guess_index == 4:
            current_guessing[-1] = ord('8')
            current_guessing[-2] = ord('e')
            potential_X = 8
        elif block_to_guess_index == 5:
            current_guessing[-1] = ord('8')
            current_guessing[-2] = ord('8')
            potential_X = ord('}') % 16 
        else:
            raise ValueError(f"Could not find valid possibility in {possibilities}")
    
    # Now that we know the last two bytes, continue to leak bytes one by one
    for byte_index_to_guess in range(len(second_cipher_block) - 3, -1, -1):
        # Reset socket so we have time to get another byte before timeout
        soc, base_iv, first_cipher_block, second_cipher_block = reset(block_to_guess_index)
        
        possibilities = [potential_X + 16*i for i in range(16)]
        for possibility in possibilities:
            if not (48 <= possibility <= 126):
                continue
            current_guessing[byte_index_to_guess] = possibility
            # Set all last bytes to the potential value of X that we try to confirm
            block = bytes([possibility ^ first_cipher_block[i] ^ current_guessing[i] for i in range(byte_index_to_guess+1, BLOCK_SIZE, 1)])
            potentialX_1, _ = get_X_and_Y_IV(
                soc,
                base_iv,
                base_cipher_block_1=first_cipher_block[:byte_index_to_guess+1] + block,
                base_cipher_block_2=second_cipher_block
            )
            block = bytes([current_guessing[-1] ^ first_cipher_block[i] ^ current_guessing[i] for i in range(byte_index_to_guess, BLOCK_SIZE-1, 1)])
            potentialX_2, _ = get_X_and_Y_IV(
                soc,
                base_iv,
                base_cipher_block_1=first_cipher_block[:byte_index_to_guess] + block + bytes([first_cipher_block[-1]]),
                base_cipher_block_2=second_cipher_block
            )
            block = bytes([current_guessing[5] ^ first_cipher_block[i] ^ current_guessing[i] for i in range(byte_index_to_guess, BLOCK_SIZE, 1)])
            potentialX_3, _ = get_X_and_Y_IV(
                soc,
                base_iv,
                base_cipher_block_1=first_cipher_block[:byte_index_to_guess] + block,
                base_cipher_block_2=second_cipher_block
            )
            if potentialX_1 == potentialX_2 and potentialX_1 == potentialX_3:
                print(f"Possibility {possibility} looks good with X={potentialX_1}")
                current_guessing[byte_index_to_guess] = possibility
                potential_X = potentialX_1
                break
        else:
            # If we did not find a valid candidate, the byte just before the one
            # we are trying to guess is equal to the byte we are guessing,
            # or to the last byte of the block.
            print(f'Whoops, two bytes are equals {block_to_guess_index}')
            if block_to_guess_index == 3:
                if byte_index_to_guess == 8:
                    current_guessing[byte_index_to_guess] = ord('a')
                    potential_X = 1
                elif byte_index_to_guess == 7:
                    current_guessing[byte_index_to_guess] = ord('a')
                    potential_X = 7
                elif byte_index_to_guess == 6:
                    current_guessing[byte_index_to_guess] = ord('7')
                    potential_X = 1
                else:
                    raise ValueError(f"Could not find candidate fitting within {possibilities}")
            elif block_to_guess_index == 5:
                if byte_index_to_guess > 5:
                    current_guessing[byte_index_to_guess] = ord('8')
                    potential_X = ord('}') % 16
                elif byte_index_to_guess == 5:
                    current_guessing[byte_index_to_guess] = ord('}')
                    potential_X = ord('"') % 16
                elif byte_index_to_guess == 4:
                    current_guessing[byte_index_to_guess] = ord('"')
                    potential_X = ord('}') % 16
                else:
                    known_blocks.append(bytes(current_guessing))
                    print(''.join(block.decode() for block in known_blocks))
                    raise ValueError("Could not find candidate")
    

    known_blocks.append(bytes(current_guessing))
    print(''.join(block.decode() for block in known_blocks))

After running this a few times, it is worth noting that the check we use to validate or invalidate a possible X is insufficient: we get some false positives, and the error propagates to the rest of the computation, making the result unusable. A solution could be to modify the script and test against more bytes than just the current one and the last one of the plaintext. This is left as an exercise to the reader… ;)

This is not my proudest CTF flag, but well, even if the solution is dirty, it gives the same score in the end.