ITFEST 2025 Finals Writeup - Cryptography
Personal writeup for ITFEST CTF 2025: Decrypt Only Revenge.
Decrypt Only Revenge
During the finals of ITFEST CTF 2025, I got the first solve (first blood) on this challenge, and it remained the only solve.
Author: agoyy
We were given a single attachment, server.py.
from Crypto.Util.number import *
from Crypto.Random import random
from math import gcd

FLAG = open("flag.txt", "rb").read()

def lcm(a, b):
    return a // gcd(a, b) * b

class Homomorphic:
    def __init__(self, bits=512):
        self.p = getPrime(bits)
        self.q = getPrime(bits)
        self.n = self.p * self.q
        self.n2 = self.n * self.n
        self.g = self.n + 1
        self._lambda = lcm(self.p - 1, self.q - 1)
        u = pow(self.g, self._lambda, self.n2)
        L_u = (u - 1) // self.n
        self.mu = inverse(L_u, self.n)

    def gen(self, num):
        return list(map(int, bin(random.getrandbits(num))[2:]))

    def encrypt(self, m):
        r = random.randrange(self.n - 1) + 1
        return (pow(self.g, m, self.n2) * pow(r, self.n, self.n2)) % self.n2

    def decrypt(self, c):
        u = pow(c, self._lambda, self.n2)
        L_u = (u - 1) // self.n
        pt = (L_u * self.mu) % self.n
        bin_pt = list(map(int, bin(pt)[2:]))
        key = self.gen(pt.bit_length())
        for i in range(len(key)):
            bin_pt[i] ^= key[i]
        return ''.join(list(map(str, bin_pt)))

def main():
    h = Homomorphic(bits=512)
    m = int.from_bytes(FLAG, 'big')
    ct = h.encrypt(m)
    print(f"N = {h.n}")
    print(f"g = {h.g}")
    print(f"enc(flag) = {ct}")
    while True:
        c = int(input("Ciphertext (in hex) : "), 16)
        m = h.decrypt(c)
        print(int(m,2))

if __name__ == "__main__":
    main()
Approach
This is a Paillier cryptosystem. I also took part in Meta4Sec CTF 2025, which had the first version of this challenge; this one is a harder version of it, hence "Revenge" in the name.
Reference challenge writeup: Meta4Sec 2025 Final PDF
The only real difference between that challenge and this one is the oracle's output, so the same logic applies here.
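Since this is a Paillier cryptosystem with g = n + 1, it is additively homomorphic: multiplying ciphertexts adds their plaintexts, and multiplying by a power of g^(-1) subtracts from the plaintext (plaintexts are taken mod n):
$$E(m_1)\cdot E(m_2) \bmod n^2 = E(m_1 + m_2), \qquad E(m)\cdot g^{-k} \bmod n^2 = E(m - k)$$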
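A toy numeric check of the additive homomorphism, and of the (1 - n) inverse of g used later, can be sketched as follows. It assumes two small Mersenne primes in place of the server's random 512-bit primes and fixed r values; it only illustrates the algebra and is not the challenge code.

```python
from math import gcd

# Toy Paillier key (assumption: Mersenne primes instead of random 512-bit primes)
p, q = 2**89 - 1, 2**127 - 1
n, n2 = p * q, (p * q) ** 2
g = n + 1
lam = (p - 1) * (q - 1) // gcd(p - 1, q - 1)   # lcm(p - 1, q - 1)
mu = pow((pow(g, lam, n2) - 1) // n, -1, n)     # same formula as the server

def encrypt(m, r):
    return pow(g, m, n2) * pow(r, n, n2) % n2

def decrypt(c):
    return (pow(c, lam, n2) - 1) // n * mu % n

c1, c2 = encrypt(1234, 5), encrypt(5678, 7)
g_inv = (1 - n) % n2                             # (n + 1)^(-1) mod n^2

print(decrypt(c1 * c2 % n2))                     # → 6912 (1234 + 5678)
print(decrypt(c1 * pow(g_inv, 1000, n2) % n2))   # → 234 (1234 - 1000)
print(g * g_inv % n2)                            # → 1
print(pow(g, 1000, n2) == (1 + 1000 * n) % n2)   # → True, since (1 + n)^k = 1 + kn mod n^2
```

The last line is why powers of g and g^(-1) are cheap to reason about: modulo n^2 they are linear in the exponent.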
Next, look at the decrypt() function:
def decrypt(self, c):
    u = pow(c, self._lambda, self.n2)
    L_u = (u - 1) // self.n
    pt = (L_u * self.mu) % self.n
    bin_pt = list(map(int, bin(pt)[2:]))
    key = self.gen(pt.bit_length())
    for i in range(len(key)):
        bin_pt[i] ^= key[i]
    return ''.join(list(map(str, bin_pt)))
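The decryptor converts pt into an unpadded bitstring, then XORs a fresh random mask onto it, aligned at the most significant bit. The mask is intended to have pt.bit_length() bits, but because it is built with bin(getrandbits(k))[2:], leading zeros are dropped and the mask may be shorter. Moreover, bin() of a positive integer always starts with 1, so the first mask bit is 1 unless getrandbits returns 0: in practice the MSB of pt is always cleared, and the bit below it is XORed with a uniformly random mask bit. Printing int(bits, 2) therefore turns every query into a noisy estimate of the bit length k of pt: the result never exceeds k, is at most k - 1 except when the mask value is 0 (probability 2^(-k)), and equals k - 1 roughly half the time. We can say this challenge is essentially a noisy bit-length oracle for pt.
So the 50% coin toss is really on the second-highest bit, not on the MSB itself. Repeating the query and taking the maximum observed length gives l = k - 1 with high probability (10 samples miss it with probability about 2^(-10)). That is not the exact bit length, but it is enough: l never exceeds the bit length of pt, so 2^(l - 1) is always at most pt and subtracting it can never overshoot. The peeling still converges, at the cost of up to about two iterations per bit.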
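To see this behaviour concretely, here is a small simulation of the masking step alone. Assumptions: Python's random module stands in for Crypto.Random, the plaintext is a random 256-bit number, and masked_output is a hypothetical helper name that mirrors the tail of decrypt().

```python
import random

def masked_output(pt, rng):
    """Replays the server's post-decryption masking for a known plaintext pt."""
    if pt == 0:
        return 0                                   # server prints 0 for pt == 0
    bin_pt = list(map(int, bin(pt)[2:]))
    # Same construction as gen(): bin() drops leading zeros, so the mask can be
    # shorter than pt, and its first bit is 1 whenever the random value is nonzero.
    key = list(map(int, bin(rng.getrandbits(pt.bit_length()))[2:]))
    for i in range(len(key)):
        bin_pt[i] ^= key[i]                        # mask is aligned with pt's MSB
    return int("".join(map(str, bin_pt)), 2)

rng = random.Random(2025)                          # assumption: stand-in for Crypto.Random
pt = rng.getrandbits(256) | (1 << 255)             # a plaintext with exactly 256 bits
lengths = [masked_output(pt, rng).bit_length() for _ in range(200)]
print(pt.bit_length(), max(lengths), sum(L == 255 for L in lengths))
```

On a typical run the maximum observed length is 255, one less than pt.bit_length(), and roughly half of the 200 samples reach it.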
Knowing this, we can recover the flag with an MSB-peeling process:
- Compute g_inv = (n + 1)^(-1) mod n^2 = (1 - n) mod n^2, since (1 + n)(1 - n) = 1 - n^2 = 1 (mod n^2).
- Create two variables: known = 0 and G = 1 (G tracks g_inv^known).
- Build the probe ciphertext ct_probe = c * G mod n^2, which is E(m - known). Sample the oracle a few times on ct_probe and record the maximum bit length l. Since l never exceeds the bit length of the remaining value m - known, 2^(l - 1) can be subtracted from it without going negative.
- Update both variables: known += 2^(l - 1) and G = G * g_inv^(2^(l - 1)) mod n^2.
- Repeat until l = 0; at that point known = m and we have recovered the flag.
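To check the peeling logic offline, here is a self-contained sketch that runs it against a local stand-in for the server. Assumptions: toy Mersenne primes instead of 512-bit primes, Python's random module instead of Crypto.Random, a hypothetical flag ITFEST{toy_flag}, and 20 samples per step; oracle() and peel() are illustrative names, not part of the challenge.

```python
import random
from math import gcd

# Toy key (assumption): Mersenne primes instead of the server's random 512-bit primes.
p, q = 2**89 - 1, 2**127 - 1
n = p * q
n2 = n * n
g = n + 1
lam = (p - 1) * (q - 1) // gcd(p - 1, q - 1)   # lcm(p - 1, q - 1)
mu = pow((pow(g, lam, n2) - 1) // n, -1, n)
rng = random.Random(1337)                      # stand-in for Crypto.Random

def encrypt(m):
    r = rng.randrange(1, n)
    return pow(g, m, n2) * pow(r, n, n2) % n2

def oracle(c):
    """Hypothetical local oracle: decrypt, XOR an MSB-aligned mask, return the integer."""
    pt = (pow(c, lam, n2) - 1) // n * mu % n
    if pt == 0:
        return 0
    bits = list(map(int, bin(pt)[2:]))
    key = list(map(int, bin(rng.getrandbits(pt.bit_length()))[2:]))
    for i in range(len(key)):
        bits[i] ^= key[i]
    return int("".join(map(str, bits)), 2)

def peel(ct, samples=20):
    g_inv = (1 - n) % n2        # (n + 1)^(-1) mod n^2
    known, G = 0, 1             # invariant: G == g_inv^known, so ct * G encrypts m - known
    while True:
        probe = ct * G % n2
        l = max(oracle(probe).bit_length() for _ in range(samples))
        if l == 0:              # remaining plaintext is 0, so known == m
            return known
        add = 1 << (l - 1)      # l never exceeds the remaining bit length, so no overshoot
        known += add
        G = G * pow(g_inv, add, n2) % n2

flag = b"ITFEST{toy_flag}"      # hypothetical flag for the demo
m = int.from_bytes(flag, "big")
recovered = peel(encrypt(m))
print(recovered.to_bytes((recovered.bit_length() + 7) // 8, "big"))
```

The loop stops only when every sample reports 0, so the remaining risk is a premature 0 on a tiny nonzero remainder; with 20 samples per step that chance is around 2^(-20) per step.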
Solver
# eter
from Crypto.Util.number import *
from pwn import *
# from Pwn4Sage.pwn import *

context.log_level = 'info'

# hostport = '...'
# HOST = hostport.split()[1]
# PORT = int(hostport.split()[2])

def sample_len(r, ct_base, g_inv_pow_known, n2):
    # ct_rem = ct_base * g^(-known) encrypts the remaining value m - known
    ct_rem = (ct_base * g_inv_pow_known) % n2
    max_bits = 0
    for _ in range(10):
        r.sendlineafter(b"Ciphertext (in hex) : ", hex(ct_rem)[2:].encode())
        line = r.recvline().strip()
        try:
            v = int(line)
        except ValueError:
            continue  # skip any line that is not a number
        max_bits = max(max_bits, v.bit_length())
    return max_bits

def main():
    r = process(['python', 'server.py'])
    r.recvuntil(b"N = ")
    N = int(r.recvline().strip())
    r.recvuntil(b"g = ")
    g = int(r.recvline().strip())
    r.recvuntil(b"enc(flag) = ")
    ct_flag = int(r.recvline().strip())

    n2 = N * N
    g_inv = (1 - N) % n2  # (g)^{-1} mod n^2 for g = n+1
    known = 0
    g_inv_pow_known = 1
    # Stop only when the oracle reports 0: stopping as soon as bit_pos reaches 0
    # could exit while a small remainder is still left.
    while True:
        rem_bitlen = sample_len(r, ct_flag, g_inv_pow_known, n2)
        if rem_bitlen == 0:
            break  # recovered
        bit_pos = rem_bitlen - 1
        add = 1 << bit_pos
        known += add
        # update g_inv_pow_known *= g_inv^{add}
        g_inv_pow_known = (g_inv_pow_known * pow(g_inv, add, n2)) % n2
        log.info(f"Recovered bit at position {bit_pos}; known now has {known.bit_length()} bits")

    print(long_to_bytes(known))
    r.interactive()

if __name__ == '__main__':
    main()
Local testing:

Result
