
[HITCON 2023] Share write-up

This is a writeup for the Share challenge of the HITCON 2023 CTF. We analyse a vulnerable implementation of Shamir’s Secret Sharing, and recover the secret with access to only a part of the shares.

Presentation of the challenge

This challenge was the most solved and most accessible of all the crypto challenges of the HITCON 2023 CTF. Thanks to maple3142 for the challenge.

Before downloading the challenge, the following text is displayed:

I hope I actually implemented Shamir Secret Sharing correctly this year. I am pretty sure you won’t be able to guess my secret even when I give you all but one share.

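The first step of the challenge was then to read about Shamir’s Secret Sharing. Shamir’s Secret Sharing (SSS in the rest of this post) is a scheme based on polynomial interpolation, and the secret is the constant term $a_0 = P(0)$ of a random polynomial. Polynomial interpolation relies on the fact that any polynomial $P$ of degree at most $N$
$P(X) = \sum_{k=0}^{N} a_{k}X^{k}$
is uniquely determined once we know $N + 1$ distinct points on it. If we only know $N$ points, however, there are still infinitely many candidate polynomials over the reals. Over $\mathbb{Z}/p\mathbb{Z}$ there are exactly $p$ of them, one for each value of $a_0$, so the points alone reveal nothing about the secret.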
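
To make this concrete, here is a minimal pure-Python sketch of Lagrange interpolation over $\mathbb{Z}/p\mathbb{Z}$. The helpers poly_mul_linear, interpolate and evaluate are our own and are not part of the challenge. With $N + 1$ points, interpolation recovers the polynomial exactly. With only $N$ points, adding any guess $(0, a_0)$ still yields a valid polynomial through them.

```python
def poly_mul_linear(coeffs: list[int], root: int, p: int) -> list[int]:
    """Multiply a polynomial (coefficients lowest degree first) by (X - root) mod p."""
    out = [0] * (len(coeffs) + 1)
    for k, c in enumerate(coeffs):
        out[k + 1] = (out[k + 1] + c) % p   # c * X^k * X
        out[k] = (out[k] - root * c) % p    # c * X^k * (-root)
    return out


def interpolate(points: list[tuple[int, int]], p: int) -> list[int]:
    """Lagrange interpolation mod a prime p; returns the coefficients a_0, a_1, ..."""
    result = [0] * len(points)
    for j, (xj, yj) in enumerate(points):
        basis, denom = [1], 1
        for m, (xm, _) in enumerate(points):
            if m != j:
                basis = poly_mul_linear(basis, xm, p)   # prod of (X - x_m)
                denom = denom * (xj - xm) % p           # prod of (x_j - x_m)
        scale = yj * pow(denom, -1, p) % p
        for k, c in enumerate(basis):
            result[k] = (result[k] + scale * c) % p
    return result


def evaluate(coeffs: list[int], x: int, p: int) -> int:
    return sum(c * pow(x, k, p) for k, c in enumerate(coeffs)) % p


p = 17
secret_poly = [5, 3, 11]                                       # a_0 = 5, degree N = 2
points = [(x, evaluate(secret_poly, x, p)) for x in (1, 2, 3)]   # N + 1 = 3 points
print(interpolate(points, p))                                  # [5, 3, 11]

# With only N = 2 points, every guess for a_0 gives a polynomial through them.
two_points = points[:2]
for a0 in range(p):
    q = interpolate([(0, a0)] + two_points, p)
    assert q[0] == a0 and all(evaluate(q, x, p) == y for x, y in two_points)
```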

The code associated with the challenge is the following:

#!/usr/bin/env python3
from Crypto.Util.number import isPrime, getRandomRange, bytes_to_long
from typing import List
import os, signal


class SecretSharing:
    def __init__(self, p: int, n: int, secret: int):
        self.p = p
        self.n = n
        self.poly = [secret] + [getRandomRange(0, self.p - 1) for _ in range(n - 1)]

    def evaluate(self, x: int) -> int:
        return (
            sum([self.poly[i] * pow(x, i, self.p) for i in range(len(self.poly))])
            % self.p
        )

    def get_shares(self) -> List[int]:
        return [self.evaluate(i + 1) for i in range(self.n)]


if __name__ == "__main__":
    signal.alarm(30)
    secret = bytes_to_long(os.urandom(32))
    while True:
        p = int(input("p = "))
        n = int(input("n = "))
        if isPrime(p) and int(13.37) < n < p:
            shares = SecretSharing(p, n, secret).get_shares()
            print("shares =", shares[:-1])
        else:
            break
    if int(input("secret = ")) == secret:
        print(open("flag.txt", "r").read().strip())

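We can see that the server generates a 32-byte secret. It then repeatedly asks us for a prime $p$ and an integer $n$ with $13 < n < p$ (so $n \geq 14$). For each such pair, it generates a random polynomial of degree at most $n-1$ in $(\mathbb{Z}/p\mathbb{Z})[X]$ whose constant coefficient $a_0$ is the secret (reduced modulo $p$ by the evaluation). It returns $n-1$ points of this polynomial, evaluated at $1, 2, 3, \ldots, n-1$. The secret stays the same across queries, but the other coefficients are drawn afresh each time. Sending invalid values (for example $p = 0$) ends the loop, and the server then asks for the secret.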

The vulnerability in the implementation

If SSS were properly implemented, this would not be enough for us to get the secret, which is the key to the flag. Fortunately, there is a subtle mistake in the implementation. getRandomRange(a, b) returns a value in $[a, b)$, so the coefficients $a_1, \ldots, a_{n-1}$ are drawn from $[0, p-2]$ instead of $[0, p-1]$ as they should be. In other words, a non-constant coefficient can never be equal to $p-1$. This leaks information about the polynomial and lets us bruteforce $a_0$.
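
The off-by-one can be observed without talking to the server. The sketch below mirrors SecretSharing.__init__, replacing getRandomRange with Python's random.randrange. Both functions use the same half-open range, so random.randrange(0, p - 1) only returns values in $[0, p-2]$. make_polynomial is a hypothetical helper name.

```python
import random


def make_polynomial(p: int, n: int, secret: int, rng: random.Random) -> list[int]:
    """Mirror of SecretSharing.__init__, with rng.randrange standing in for getRandomRange.

    Both are half-open: randrange(0, p - 1) returns a value in [0, p - 2].
    """
    return [secret] + [rng.randrange(0, p - 1) for _ in range(n - 1)]


rng = random.Random(0)
p, n = 17, 14
seen = set()
for _ in range(2000):
    # Collect only the non-constant coefficients a_1..a_{n-1}
    seen.update(make_polynomial(p, n, secret=42, rng=rng)[1:])

print(p - 1 in seen)   # False: 16 is never drawn
print(sorted(seen))    # the values that were actually drawn
```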

Bruteforcing $a_0$ mod $p$

Indeed, suppose we interpolate the polynomial through the $n-1$ shares and the guess $a_0 = 0$. If one of the resulting coefficients $a_i, i \in [1, n-1]$ is equal to $p-1$, then $a_0$ cannot be 0. The polynomial we are looking for has all its non-constant coefficients between $0$ and $p-2$.
If we repeat this for every candidate $a_0 \in [0, p-1]$, we may be able to rule out all possibilities but one. If more than one candidate is left, we request a new batch of shares. The secret is the same, but the other coefficients are fresh, so each batch gives new chances to eliminate wrong candidates. This lets us find the value of $a_0$ mod $p$.
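
The elimination for a single prime can be sketched in pure Python against a local simulation of the server. This is an illustration under stated assumptions, not the author's script. interpolate, server_shares and recover_mod_p are hypothetical helpers, and random.randrange stands in for getRandomRange. The check deliberately skips the constant coefficient, because $a_0$ itself may be equal to $p-1$.

```python
import random


def poly_mul_linear(coeffs, root, p):
    """Multiply a polynomial (lowest degree first) by (X - root) mod p."""
    out = [0] * (len(coeffs) + 1)
    for k, c in enumerate(coeffs):
        out[k + 1] = (out[k + 1] + c) % p
        out[k] = (out[k] - root * c) % p
    return out


def interpolate(points, p):
    """Lagrange interpolation mod p; returns coefficients a_0, a_1, ..."""
    result = [0] * len(points)
    for j, (xj, yj) in enumerate(points):
        basis, denom = [1], 1
        for m, (xm, _) in enumerate(points):
            if m != j:
                basis = poly_mul_linear(basis, xm, p)
                denom = denom * (xj - xm) % p
        scale = yj * pow(denom, -1, p) % p
        for k, c in enumerate(basis):
            result[k] = (result[k] + scale * c) % p
    return result


def server_shares(p, n, secret, rng):
    """Hypothetical local stand-in for the server: shares at x = 1..n-1, coefficients in [0, p - 2]."""
    poly = [secret] + [rng.randrange(0, p - 1) for _ in range(n - 1)]
    return [sum(c * pow(x, k, p) for k, c in enumerate(poly)) % p for x in range(1, n)]


def recover_mod_p(p, n, query):
    """Eliminate candidates for a_0 mod p until a single one remains."""
    candidates = set(range(p))
    while len(candidates) > 1:
        points = list(enumerate(query(), start=1))   # (1, s_1), ..., (n-1, s_{n-1})
        for c in list(candidates):
            coeffs = interpolate([(0, c)] + points, p)
            # Only a_1..a_{n-1} are restricted to [0, p - 2]; a_0 may be p - 1.
            if p - 1 in coeffs[1:]:
                candidates.discard(c)
    return candidates.pop()


rng = random.Random(1)
secret = rng.getrandbits(256)
p, n = 17, 14
r = recover_mod_p(p, n, lambda: server_shares(p, n, secret, rng))
print(r == secret % p)   # True
```

The true candidate is never eliminated, since interpolating with it gives back the real polynomial. The loop therefore always ends on the correct remainder.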

But if we stopped there, we would still have roughly $2^{32 \times 8} / p = 2^{256} / p$ possibilities to explore. That is a lot.

Finding the secret

But we can also do this for several primes $p$. How, then, do we combine the remainders obtained for different primes to find the secret? That is where the Chinese remainder theorem comes into play. Quoting Wikipedia’s page:

In mathematics, the Chinese remainder theorem states that if one knows the remainders of the Euclidean division of an integer n by several integers, then one can determine uniquely the remainder of the division of n by the product of these integers, under the condition that the divisors are pairwise coprime (no two divisors share a common factor other than 1).

Our divisors here are pairwise coprime, because they are distinct primes. We also know that $secret < 2^{32 \times 8} = 2^{256}$, because it is a 32-byte integer. So if we do this for enough primes $(p_i)_{i \in [0,m]}$ that $\prod_{i \in [0,m]}p_i > 2^{256}$, the remainder modulo this product is the secret itself. We could also have used a single prime $p$ larger than $2^{256}$. However, there would then be about $2^{256}$ candidates for $a_0$ mod $p$ to check, and each coefficient would hit $p-1$ with probability only about $2^{-256}$. Discarding the wrong candidates would take far too many iterations.
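
Prime selection and recombination can be sketched with the standard library alone. Here crt is our own pure-Python stand-in for Sage's crt. It uses the textbook construction $x = \sum_i r_i M_i (M_i^{-1} \bmod p_i) \bmod M$, with $M = \prod_i p_i$ and $M_i = M / p_i$. is_prime is a naive trial-division helper, adequate only for small numbers. The primes are chosen greater than 14, so that $n = 14 < p$ is accepted by the server.

```python
import random
from math import isqrt, prod


def is_prime(k: int) -> bool:
    """Naive trial division; fine for the small primes used here."""
    return k > 1 and all(k % d for d in range(2, isqrt(k) + 1))


def crt(remainders: list[int], moduli: list[int]) -> int:
    """Chinese remainder theorem for pairwise coprime moduli."""
    M = prod(moduli)
    x = 0
    for r, m in zip(remainders, moduli):
        Mi = M // m
        x += r * Mi * pow(Mi, -1, m)   # contributes r mod m, and 0 mod every other modulus
    return x % M


# Pick primes greater than n = 14 until their product exceeds 2^256.
primes, product, q = [], 1, 14
while product < 2**256:
    q += 1
    if is_prime(q):
        primes.append(q)
        product *= q

secret = random.getrandbits(256)                 # a 32-byte secret
remainders = [secret % p for p in primes]
assert crt(remainders, primes) == secret
print(f"{len(primes)} primes, largest {primes[-1]}")
```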

Implementation

To sum up, here is our plan:

  1. Find several primes greater than $n = 14$ whose product is larger than $2^{256}$.
  2. For each prime $p$:
    1. Request a new batch of shares with $n = 14$.
    2. Use interpolation to figure out which values the secret cannot be equal to (modulo $p$).
    3. Repeat until only one candidate for the secret modulo $p$ remains.
  3. Apply the Chinese remainder theorem to get the secret.

This gives us the following code:

from sage.all import GF, crt
from pwn import remote
from Cryptodome.Util.number import sieve_base
import json

# Step 1
primes = []
product = 1
i = 0
while product < 2**256:
    if sieve_base[i] < 14:
        i += 1
        continue
    product *= sieve_base[i]
    primes.append(sieve_base[i])
    i += 1
print(f"Selected primes: {primes}")
# Step 2
soc = remote("localhost", 11111)
# Step 2.1
def get_shares(p: int) -> list[int]:
    query = f"{p}\n14\n"
    soc.send(query.encode())
    shares = json.loads(soc.recvline().strip().decode()[17:])
    return shares

remainders = []
for (i, p) in enumerate(primes):
    print(f"Computing modulo for prime {p} ({i+1}/{len(primes)})")
    possible_mods = set(range(p))
    gf = GF(p)["x"]
    # Step 2.3
    while len(possible_mods) != 1:
        # Step 2.1
        shares = get_shares(p)
        # Step 2.2
        for possible_mod in list(possible_mods): # copies the set
            polynomial = gf.lagrange_polynomial(enumerate([possible_mod] + shares))
            if p-1 in polynomial.coefficients(sparse=False)[1:]:  # skip a_0, which may legitimately be p-1
                possible_mods.remove(possible_mod)
    remainders.append(possible_mods.pop())

# Step 3
secret = int(crt(remainders, primes))
soc.send(f"0\n0\n{secret}\n".encode())
print(soc.recvline().strip())
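
The [17:] slice in get_shares relies on the exact prompt layout. input("p = ") and input("n = ") print their prompts without a trailing newline, so the line the client reads is "p = n = shares = [...]", and that prefix is 17 characters long. The sketch below is slightly more defensive: it splits on the "shares = " marker instead. parse_shares is a hypothetical helper name, not part of pwntools.

```python
import json


def parse_shares(line: bytes) -> list[int]:
    """Extract the share list from a raw server line of the form 'p = n = shares = [...]'."""
    text = line.decode().strip()
    _, sep, payload = text.partition("shares = ")
    if not sep:
        raise ValueError(f"unexpected server line: {text!r}")
    return json.loads(payload)


print(parse_shares(b"p = n = shares = [3, 14, 15]\n"))  # [3, 14, 15]
```

With this helper, the last line of get_shares could become return parse_shares(soc.recvline()).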

Disappointment and the solution

Unfortunately, this is too slow. The server has a 30-second timeout, and my computer only gets through about half of the computation in that time. We can find out why with a quick profiling run:

$ python -m cProfile solve.py

        2588392 function calls (2550532 primitive calls) in 30.992 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   1323/1    0.004    0.000   31.183   31.183 {built-in method builtins.exec}
        1    0.191    0.191   31.183   31.183 solve.py:1(<module>)
      641    0.012    0.000   26.530    0.041 solve.py:23(get_shares)
      641    0.004    0.000   26.440    0.041 tube.py:463(recvline)
      641    0.043    0.000   26.429    0.041 tube.py:280(recvuntil)
     1920    0.006    0.000   26.342    0.014 tube.py:73(recv)
     1920    0.012    0.000   26.199    0.014 tube.py:165(_recv)
     1280    0.021    0.000   26.060    0.020 tube.py:130(_fillbuffer)
     1280    0.009    0.000   25.882    0.020 sock.py:33(recv_raw)
     1280   25.874    0.020   25.874    0.020 {method 'recv' of '_socket.socket' objects}
    11623    0.787    0.000    3.345    0.000 polynomial_ring.py:2190(lagrange_polynomial)
    11623    1.404    0.000    2.167    0.000 polynomial_ring.py:2100(divided_difference)
      123    0.002    0.000    1.371    0.011 all.py:1(<module>)
   1989/7    0.007    0.000    0.680    0.097 <frozen importlib._bootstrap>:1165(_find_and_load)
   1965/7    0.006    0.000    0.680    0.097 <frozen importlib._bootstrap>:1120(_find_and_load_unlocked)
  1832/10    0.004    0.000    0.679    0.068 <frozen importlib._bootstrap>:666(_load_unlocked)
  ...

We can see that most of our time (26.5 of the 31 seconds) is spent in recvline, waiting for the server. The Lagrange interpolations take only about 3.3 seconds.
One solution could be to try a bigger n. Computing the Lagrange polynomials would certainly be slower, but each batch would then have a higher chance of producing a coefficient equal to p-1 for every wrong candidate. We would therefore need fewer batches of shares.
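
The intuition behind a bigger $n$ can be made precise under one assumption: that $a_1, \ldots, a_{n-1}$ are independent and uniform on $[0, p-2]$, which is what getRandomRange(0, p - 1) is meant to produce. Take a wrong candidate $c = a_0 + d$ with $d \not\equiv 0 \pmod p$. The polynomial interpolated through $(0, c)$ and the shares is $P + d\,L$. Here $L(X) = \prod_{i=1}^{n-1} (1 - X/i)$ is the Lagrange basis polynomial with $L(0) = 1$ and $L(i) = 0$ for $i \in [1, n-1]$. Its coefficient of degree $k$ is $a_k + d\,\ell_k$. If $\ell_k \not\equiv 0 \pmod p$, this coefficient equals $p-1$ with probability $1/(p-1)$; otherwise it never does. A wrong candidate therefore survives one batch with probability $(1 - 1/(p-1))^m$, where $m \leq n-1$ is the number of nonzero $\ell_k$ for $k \geq 1$. The sketch below computes this quantity; zero_basis_coeffs and survival_probability are hypothetical helpers.

```python
def zero_basis_coeffs(p: int, n: int) -> list[int]:
    """Coefficients (mod p, lowest degree first) of L(X) = prod_{i=1}^{n-1} (1 - X/i).

    L is the Lagrange basis polynomial with L(0) = 1 and L(i) = 0 for i = 1..n-1.
    """
    coeffs = [1]
    for i in range(1, n):
        inv_i = pow(i, -1, p)                # i < n < p, so i is invertible mod p
        nxt = coeffs + [0]                   # the "1 *" part of (1 - X/i)
        for k, c in enumerate(coeffs):
            nxt[k + 1] = (nxt[k + 1] - inv_i * c) % p   # the "- X/i *" part
        coeffs = nxt
    return coeffs


def survival_probability(p: int, n: int) -> float:
    """Chance that one wrong candidate for a_0 survives a single batch of n - 1 shares,
    assuming a_1..a_{n-1} are independent and uniform on [0, p - 2]."""
    m = sum(1 for c in zero_basis_coeffs(p, n)[1:] if c != 0)
    return (1 - 1 / (p - 1)) ** m


for p, n in [(17, 14), (17, 16), (101, 14), (101, 100)]:
    print(p, n, round(survival_probability(p, n), 3))
```

Under this model, a larger $n$ gives each batch up to $n-1$ chances to eliminate every wrong candidate, at the cost of a more expensive interpolation.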

Another fix is simply to reduce the time it takes to get shares. Looking at the server code, we can see that nothing prevents us from sending several queries at once, because the server reads them one after the other. So we can send a bunch of {p}\n14\n queries in a single message (with p replaced by our prime) and then read the responses sequentially. This pays the network round trip once per group of queries instead of once per query. I did this by implementing a ShareProvider class, which fetches the shares 10 batches at a time.

class ShareProvider:
    SHARE_BUFFER = 10
    def __init__(self, p: int):
        self.p = p
        self.shares = []

    def next_share(self) -> list[int]:
        if not self.shares:
            # Pipeline SHARE_BUFFER queries in a single send, then read the answers in order
            query = f"{self.p}\n14\n" * ShareProvider.SHARE_BUFFER
            soc.send(query.encode())
            for i in range(ShareProvider.SHARE_BUFFER):
                self.shares.append(json.loads(soc.recvline().strip().decode()[17:]))
        return self.shares.pop()

Then, we create sp = ShareProvider(p) for each prime, right before the while loop, and replace the old step 2.1 with:

# Step 2.1
shares = sp.next_share()

It now finishes in time. If we look at the profile again:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   1323/1    0.004    0.000   13.308   13.308 {built-in method builtins.exec}
        1    0.244    0.244   13.308   13.308 solve.py:1(<module>)
     1787    0.018    0.000    7.354    0.004 solve.py:28(next_share)
     1971    0.006    0.000    7.284    0.004 tube.py:463(recvline)
     1971    0.049    0.000    7.262    0.004 tube.py:280(recvuntil)
     2449    0.006    0.000    7.153    0.003 tube.py:73(recv)
     2449    0.009    0.000    6.983    0.003 tube.py:165(_recv)
      479    0.005    0.000    6.842    0.014 tube.py:130(_fillbuffer)
      479    0.002    0.000    6.803    0.014 sock.py:33(recv_raw)
      479    6.801    0.014    6.801    0.014 {method 'recv' of '_socket.socket' objects}
    46970    1.047    0.000    4.439    0.000 polynomial_ring.py:2190(lagrange_polynomial)
    46970    1.917    0.000    2.878    0.000 polynomial_ring.py:2100(divided_difference)

We still spend a good part of our time (about 7.4 of the 13.3 seconds) making requests, and we could optimize further by introducing parallelism. But this is good enough for now: we can enjoy our flag!