[HITCON 2023] Share write-up

This is a writeup for the Share challenge of the HITCON 2023 CTF. We analyse a vulnerable implementation of Shamir’s Secret Sharing, and recover the secret with access to only a part of the shares.

Presentation of the challenge

This challenge was the most solved and the most accessible of the crypto challenges of the HITCON 2023 CTF. Thanks to maple3142 for the challenge.

Before downloading the challenge, we are shown the following description:

I hope I actually implemented Shamir Secret Sharing correctly this year. I am pretty sure you won’t be able to guess my secret even when I give you all but one share.

The first step of the challenge was then to read about Shamir’s Secret Sharing. Shamir’s Secret Sharing (SSS in the rest of this post) is a scheme based on polynomial interpolation. Polynomial interpolation is the fact that any given polynomial $P$ of degree $N$
$P(X) = \sum_{k=0}^{N} a_{k}X^{k}$
can be uniquely reconstructed if we know $N + 1$ distinct points on this polynomial. But if you know only $N$ points, many polynomials remain possible: over the reals there are infinitely many, and modulo a prime $p$ there is one for every possible value of the missing point.
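To make this concrete, here is a minimal pure-Python sketch of Lagrange interpolation modulo a small prime. The helper names (`poly_mul_linear`, `interpolate`, `evaluate`) and the toy polynomial are illustrative assumptions, not part of the challenge.

```python
def poly_mul_linear(poly, root, p):
    """Multiply poly (constant coefficient first) by (X - root) modulo p."""
    result = [0] * (len(poly) + 1)
    for k, coeff in enumerate(poly):
        result[k] = (result[k] - root * coeff) % p
        result[k + 1] = (result[k + 1] + coeff) % p
    return result


def interpolate(points, p):
    """Coefficients of the unique polynomial of degree < len(points)
    passing through the given (x, y) points modulo the prime p."""
    coeffs = [0] * len(points)
    for i, (xi, yi) in enumerate(points):
        basis = [1]  # numerator of the i-th Lagrange basis polynomial
        denom = 1
        for j, (xj, _) in enumerate(points):
            if j != i:
                basis = poly_mul_linear(basis, xj, p)
                denom = denom * (xi - xj) % p
        scale = yi * pow(denom, -1, p) % p  # modular inverse, Python 3.8+
        for k, coeff in enumerate(basis):
            coeffs[k] = (coeffs[k] + scale * coeff) % p
    return coeffs


def evaluate(coeffs, x, p):
    """Evaluate the polynomial with the given coefficients at x modulo p."""
    return sum(c * pow(x, k, p) for k, c in enumerate(coeffs)) % p


p = 17
secret_poly = [5, 3, 11]  # toy example: P(X) = 5 + 3X + 11X^2, so N = 2
points = [(x, evaluate(secret_poly, x, p)) for x in (1, 2, 3)]

# N + 1 = 3 points: the polynomial is fully determined.
recovered = interpolate(points, p)

# Only N = 2 points: every guess for P(0) gives a different valid polynomial.
candidates = [interpolate([(0, guess)] + points[:2], p) for guess in range(p)]
```

With three points, `recovered` is the original coefficient list. With two points, all 17 guesses for $P(0)$ produce a polynomial that fits, which is exactly why a correct SSS leaks nothing when one share is missing.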

The code associated with the challenge is the following:

#!/usr/bin/env python3
from Crypto.Util.number import isPrime, getRandomRange, bytes_to_long
from typing import List
import os, signal

class SecretSharing:
    def __init__(self, p: int, n: int, secret: int):
        self.p = p
        self.n = n
        self.poly = [secret] + [getRandomRange(0, self.p - 1) for _ in range(n - 1)]

    def evaluate(self, x: int) -> int:
        return (
            sum([self.poly[i] * pow(x, i, self.p) for i in range(len(self.poly))])
            % self.p
        )

    def get_shares(self) -> List[int]:
        return [self.evaluate(i + 1) for i in range(self.n)]

if __name__ == "__main__":
    signal.alarm(30)  # the connection is killed after 30 seconds
    secret = bytes_to_long(os.urandom(32))
    while True:
        p = int(input("p = "))
        n = int(input("n = "))
        if isPrime(p) and int(13.37) < n < p:
            shares = SecretSharing(p, n, secret).get_shares()
            print("shares =", shares[:-1])
        else:
            break
    if int(input("secret = ")) == secret:
        print(open("flag.txt", "r").read().strip())

We can see that the server generates a 32-byte secret, then repeatedly asks us for a prime p and an integer n with 13 < n < p (so n is at least 14). Each time, it generates a random polynomial of degree at most n-1 in $(\mathbb{Z}/p\mathbb{Z})[X]$, whose constant coefficient $a_0$ is the secret, and returns n-1 points of the polynomial (evaluated at 1, 2, 3, …, n-1).

The vulnerability in the implementation

If SSS were properly implemented, this would not be enough for us to get the secret, which is the key to the flag. Fortunately, there is a subtle mistake in the implementation. The random coefficients of the polynomial are generated between $0$ and $p-2$, instead of $0$ and $p-1$ as they should be, because getRandomRange excludes its upper bound. This means that we can try to bruteforce $a_0$ modulo $p$.

Bruteforcing $a_0$ mod $p$

Indeed, if we try to interpolate the polynomial with $a_0 = 0$ and find a coefficient $a_i$, $i \in [1, n-1]$, equal to p-1, then $a_0$ cannot be 0 modulo p, because we know that the polynomial we are looking for has all its non-constant coefficients between 0 and p-2.
If we repeat this for every candidate $a_0 \in [0, p-1]$, we may be able to exclude all possibilities for $a_0$ but one (if more than one candidate is left at the end of the iteration, we can try again by requesting new shares). This gives us the value of $a_0$ mod p.
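The elimination idea can be tried locally before touching the real server. The sketch below is an illustration under assumptions: `fake_server_shares` is a hypothetical stand-in for the challenge server (it uses `random.randrange` instead of `getRandomRange`), `interpolate` is a small pure-Python Lagrange interpolation rather than the Sage routine, and `recover_residue` is a name made up for this example.

```python
import random


def interpolate(points, p):
    """Coefficients (constant term first) of the polynomial of degree
    < len(points) through the given points, modulo the prime p."""
    coeffs = [0] * len(points)
    for i, (xi, yi) in enumerate(points):
        basis, denom = [1], 1
        for j, (xj, _) in enumerate(points):
            if j == i:
                continue
            # Multiply the basis polynomial by (X - xj).
            basis = [(a - xj * b) % p for a, b in zip([0] + basis, basis + [0])]
            denom = denom * (xi - xj) % p
        scale = yi * pow(denom, -1, p) % p
        coeffs = [(c + scale * b) % p for c, b in zip(coeffs, basis)]
    return coeffs


def fake_server_shares(secret, p, n, rng):
    """Hypothetical local stand-in for the challenge server."""
    # Same flaw as the challenge: a_1 .. a_{n-1} lie in [0, p - 2].
    poly = [secret] + [rng.randrange(0, p - 1) for _ in range(n - 1)]
    shares = [sum(c * pow(x, k, p) for k, c in enumerate(poly)) % p
              for x in range(1, n + 1)]
    return shares[:-1]  # all but one share, as on the server


def recover_residue(secret, p, n=14, seed=0):
    """Return (secret mod p, number of share lists needed)."""
    rng = random.Random(seed)
    candidates = set(range(p))
    rounds = 0
    while len(candidates) > 1:
        shares = fake_server_shares(secret, p, n, rng)
        rounds += 1
        for guess in list(candidates):  # iterate over a copy of the set
            points = [(0, guess)] + list(enumerate(shares, start=1))
            coeffs = interpolate(points, p)
            # Only a_1 .. a_{n-1} are restricted; a_0 itself may be p - 1.
            if p - 1 in coeffs[1:]:
                candidates.discard(guess)
    return candidates.pop(), rounds


secret = random.Random(1).getrandbits(256)
residue, rounds = recover_residue(secret, p=17)
```

The correct residue can never be discarded, because interpolating with the right $a_0$ gives back the server's own polynomial, whose non-constant coefficients never reach p-1. Every wrong candidate has some chance of being discarded with each new list of shares.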

But if we stopped there, we would still have roughly $2^{32 \times 8} / p = 2^{256} / p$ possibilities to explore. That is a lot.

Finding the secret

But we can also do this for several primes p. Then, how do we combine the remainders for the different primes to find the secret? That is where the Chinese remainder theorem comes into play. Quoting from the Wikipedia page:

In mathematics, the Chinese remainder theorem states that if one knows the remainders of the Euclidean division of an integer n by several integers, then one can determine uniquely the remainder of the division of n by the product of these integers, under the condition that the divisors are pairwise coprime (no two divisors share a common factor other than 1).

Our moduli here are obviously pairwise coprime, because they are all distinct primes. And we know for a fact that $secret < 2^{32 \times 8} = 2^{256}$ because it is a 32-byte integer. This means that if we do this for enough primes $(p_i)_{i \in [0,m]}$ such that $\prod_{i \in [0,m]}p_i > 2^{256}$, we will have found the secret. We could also have done it with a single prime p bigger than $2^{256}$, but then we would need far too many iterations to discard all the impossible residues.
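As a small illustration of the recombination step, here is a pure-Python sketch of the Chinese remainder theorem on a toy secret. The `crt` function below is written for this example only (in the actual solver, Sage's `crt` does this job), and the primes and the secret are arbitrary toy values.

```python
from math import prod


def crt(residues, moduli):
    """Combine x = r_i (mod m_i), for pairwise coprime m_i, into x mod prod(m_i)."""
    x, modulus = 0, 1
    for r, m in zip(residues, moduli):
        # Choose t so that x + modulus * t is congruent to r modulo m.
        t = (r - x) * pow(modulus, -1, m) % m
        x += modulus * t
        modulus *= m
    return x


primes = [17, 19, 23, 29]
secret = 123456  # toy secret, smaller than the product of the primes
assert secret < prod(primes)

residues = [secret % q for q in primes]
recovered = crt(residues, primes)  # equals the toy secret
```

The reconstruction is only unique below the product of the moduli, which is why the real attack needs enough primes for their product to exceed $2^{256}$.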


To sum up, here is our plan:

  1. Find several primes, each larger than 14, such that their product is larger than $2^{256}$.
  2. For each prime $p$:
    1. Request a new list of shares with $n = 14$.
    2. Use interpolation to figure out what secret cannot be equal to (modulo p).
    3. Continue until we know what secret is equal to modulo p.
  3. Apply the Chinese remainder theorem to get secret.

This gives us the following code:

from sage.all import GF, crt
from pwn import remote
from Cryptodome.Util.number import sieve_base
import json

# Step 1
primes = []
product = 1
i = 0
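while product < 2**256:
    # The server requires n < p with n = 14, so skip primes below 14.
    if sieve_base[i] < 14:
        i += 1
        continue
    primes.append(sieve_base[i])
    product *= sieve_base[i]
    i += 1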
print(f"Selected primes: {primes}")
# Step 2
soc = remote("localhost", 11111)
# Step 2.1
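def get_shares(p: int) -> list[int]:
    query = f"{p}\n14\n"
    soc.send(query.encode())
    # The prompts have no newline, so the reply line is "p = n = shares = [...]";
    # skip that 17-character prefix before parsing the list.
    shares = json.loads(soc.recvline().strip().decode()[17:])
    return shares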

remainders = []
for (i, p) in enumerate(primes):
    print(f"Computing modulo for prime {p} ({i+1}/{len(primes)})")
    possible_mods = set(range(p))
    gf = GF(p)["x"]
    # Step 2.3
    while len(possible_mods) != 1:
        # Step 2.1
        shares = get_shares(p)
        # Step 2.2
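        for possible_mod in list(possible_mods): # copies the set
            polynomial = gf.lagrange_polynomial(enumerate([possible_mod] + shares))
            # Only a_1..a_{n-1} are limited to [0, p-2]; the constant term may be p-1.
            if p-1 in polynomial.coefficients(sparse=False)[1:]:
                possible_mods.remove(possible_mod)
    # Exactly one candidate is left: it is the secret modulo p.
    remainders.append(possible_mods.pop())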

# Step 3
secret = int(crt(remainders, primes))

Disillusion and the solution

Unfortunately, this is too slow: the server has a 30-second timeout, and my computer only gets through about half of the computation in that time. But we can run a quick profiling with

$ python -m cProfile

        2588392 function calls (2550532 primitive calls) in 30.992 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   1323/1    0.004    0.000   31.183   31.183 {built-in method builtins.exec}
        1    0.191    0.191   31.183   31.183<module>)
      641    0.012    0.000   26.530    0.041
      641    0.004    0.000   26.440    0.041
      641    0.043    0.000   26.429    0.041
     1920    0.006    0.000   26.342    0.014
     1920    0.012    0.000   26.199    0.014
     1280    0.021    0.000   26.060    0.020
     1280    0.009    0.000   25.882    0.020
     1280   25.874    0.020   25.874    0.020 {method 'recv' of '_socket.socket' objects}
    11623    0.787    0.000    3.345    0.000
    11623    1.404    0.000    2.167    0.000
      123    0.002    0.000    1.371    0.011<module>)
   1989/7    0.007    0.000    0.680    0.097 <frozen importlib._bootstrap>:1165(_find_and_load)
   1965/7    0.006    0.000    0.680    0.097 <frozen importlib._bootstrap>:1120(_find_and_load_unlocked)
  1832/10    0.004    0.000    0.679    0.068 <frozen importlib._bootstrap>:666(_load_unlocked)

We can see that most of our time is spent in recvline, and only about 20% of the execution time is spent actually computing.
A solution could be to try a bigger n. The computation of the Lagrange polynomials would certainly be slower, but there would be a higher chance of getting a coefficient equal to p-1, so we would need fewer lists of shares.
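To get a feeling for how much a bigger n would help, here is a rough back-of-the-envelope estimate. It is a heuristic, not a proof: it assumes that, for a wrong guess of $a_0$, each of the n-1 non-constant coefficients independently lands on p-1 with probability about $1/(p-1)$. The function name `elimination_probability` is made up for this sketch.

```python
def elimination_probability(p: int, n: int) -> float:
    """Heuristic chance that one list of shares rules out a given wrong guess.

    Assumes each of the n - 1 non-constant coefficients of the wrongly
    interpolated polynomial equals p - 1 with probability 1 / (p - 1).
    """
    return 1 - (1 - 1 / (p - 1)) ** (n - 1)


# Compare a few values of n for a fixed prime (n must stay below p).
for n in (14, 30, 60):
    print(n, round(elimination_probability(101, n), 3))
```

Under this assumption the chance of discarding a wrong candidate grows with n, which matches the intuition above, but each interpolation also gets more expensive, so it is a trade-off rather than a free win.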

Another fix is simply to lower the time it takes to get shares. Looking at the server code, we see that we can actually queue several queries at once: we just have to send a bunch of {p}\n14\n in one go (after replacing p with our value), and then read the responses sequentially. I did this by implementing a ShareProvider class that fetches the lists of shares 10 at a time.

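class ShareProvider:
    SHARE_BUFFER = 10  # number of share lists requested per round trip

    def __init__(self, p: int):
        self.p = p
        self.shares = []

    def next_share(self) -> list[int]:
        if not self.shares:
            # Send all the queries in one go, then read the replies one by one.
            query = f"{self.p}\n14\n" * ShareProvider.SHARE_BUFFER
            soc.send(query.encode())
            for i in range(ShareProvider.SHARE_BUFFER):
                line = soc.recvline().strip().decode()[17:]
                self.shares.append(json.loads(line))
        return self.shares.pop()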

Then, we create one provider per prime (sp = ShareProvider(p)) and replace the old step 2.1 with:

# Step 2.1
shares = sp.next_share()

It now finishes in time, and if we look at the profiling again:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   1323/1    0.004    0.000   13.308   13.308 {built-in method builtins.exec}
        1    0.244    0.244   13.308   13.308<module>)
     1787    0.018    0.000    7.354    0.004
     1971    0.006    0.000    7.284    0.004
     1971    0.049    0.000    7.262    0.004
     2449    0.006    0.000    7.153    0.003
     2449    0.009    0.000    6.983    0.003
      479    0.005    0.000    6.842    0.014
      479    0.002    0.000    6.803    0.014
      479    6.801    0.014    6.801    0.014 {method 'recv' of '_socket.socket' objects}
    46970    1.047    0.000    4.439    0.000
    46970    1.917    0.000    2.878    0.000

We still spend much of our time waiting for responses, and we could optimize further by introducing parallelism, but this is good enough for now. We can enjoy our flag!