类型:Crypto
翻译自:https://github.com/p4-team/ctf/tree/master/2016-04-15-plaid-ctf/crypto_rabit
考察知识点:LSB-Oracle,rabin cipher
题目描述:
Just give me a bit, the least significant's enough. Just a second we’re not broken, just very, very insecure.
Running at rabit.pwning.xxx:7763
题目给了服务器文件,分析代码可以知道采用的是rabin加密 message^2 mod N 访问服务器后,服务器会发来N和加密过的PT,即CT=powmod(PT,2,N)。之后我们可以通过服务器请求我们发送的密文对应的明文的最低比特位(least-significant-bit (LSB))
对应的代码是:
while True:
req.sendall("Give a ciphertext: ")
x = long(recvline())
m = decrypt(x, p, q)
if m == None:
m = 0
req.sendall("lsb is {}\n".format(m % 2))
我们可以通过不断泄露最低比特位,采用类似二分搜索的算法获得最后的明文。
具体原理如下:
显然如果一个数A乘以2,那么结果B=2*A一定是偶数,它的LSB是0。如果我们用B模一个奇数N,那么会有两个结果:
- 如果B<N, 那么结果依旧是偶数,LSB 是0
- 如果B>N,那么结果必是奇数,LSB是1
在题目中,由于采用了rabin加密,N=p * q(p,q为大素数),因此N必然是奇数。这样如果我们让服务器告诉我们(2*PT)mod N的结果,根据不同的返回结果,我们可以有如下结论
- 如果LSB为0,那么2*PT <N,PT<N/2
- 如果LSB为1,那么2*PT >N,PT>N/2
这样我们就根据LSB,缩小了PT的范围。
进一步的,如果我们让服务器求 (4*PT) mod N的LSB,根据返回结果,我们能进一步知道:
- 如果LSB为0,那么4*PT<N, 如果PT< N/2 那么PT< N/4;如果PT>N/2 那么 PT< 3*N/4
- 如果LSB为1,那么4*PT>N, 如果PT< N/2 那么PT> N/4;如果PT>N/2 那么 PT> 3*N/4
通过这样的二分查找,我们可以不停逼近PT的上下界,最终使上下界不再变化。
由于服务器会先对我们发送的内容进行rabin解密,即对发送的内容做模N上的开平方运算,因此我们想让服务器求 (2PT) mod N,我们应当向服务器发送 4CT,这是因为:
sqrt_mod(4CT, N) = sqrt_mod(4,N)sqrt_mod(CT,N) = 2*PT mod N
通过脚本,自动化二分查找的过程:
def oracle(ciphertext, s):
print("sent ciphertext " + str(ciphertext))
s.sendall(str(ciphertext) + "\n")
data = recvline(s)
print("oracle response: " + data)
lsb = int(re.findall("lsb is (.*)", data)[0])
return lsb
def brute_flag(encrypted_flag, N, socket):
flag_lower_bound = 0
flag_upper_bound = N
mult = 0
ciphertext = (encrypted_flag * pow(4, mult)) % N
while flag_upper_bound > flag_lower_bound:
data = s.recv(512)
ciphertext = (ciphertext * 4) % N
mult += 1
print("main loop: " + data)
print("upper = %d" % flag_upper_bound)
print("upper flag = %s" % long_to_bytes(flag_upper_bound))
print("lower = %d" % flag_lower_bound)
print("lower flag = %s" % long_to_bytes(flag_lower_bound))
print("multiplier = %d" % mult)
if oracle(ciphertext, socket) == 0:
flag_upper_bound = (flag_upper_bound + flag_lower_bound) / 2
else:
flag_lower_bound = (flag_upper_bound + flag_lower_bound) / 2
return flag_upper_bound
就可以求出flag,需要注意的是,这道题并不需要完全求出800比特,因为后面很多都是padding,真正的flag只有前25字节。