Module cryptanalysis.differential_cryptanalysis
Module for performing differential cryptanalysis on Substitution Permutation Network based ciphers.
Classes: - DifferentialCryptanalysis: Class for performing differential cryptanalysis.
Usage:
Import the differential_cryptanalysis
module to access the DifferentialCryptanalysis
class.
Expand source code
"""
Module for performing differential cryptanalysis on Substitution Permutation Network based ciphers.
Classes:
- DifferentialCryptanalysis: Class for performing differential cryptanalysis.
Usage:
Import the `differential_cryptanalysis` module to access the `DifferentialCryptanalysis` class.
"""
from itertools import product
from collections import Counter
import random
from .cryptanalysis import Cryptanalysis
__all__ = ["DifferentialCryptanalysis"]
class DifferentialCryptanalysis(Cryptanalysis):
    """
    Class for performing differential cryptanalysis.

    Methods:
    - __init__: Initialize the differential cryptanalysis algorithm.
    - find_keybits: Find the key bits using differential cryptanalysis.
    - find_last_roundkey: Find the last round key using differential cryptanalysis.
    - generate_encryption_pairs: Generate encryption pairs for differential cryptanalysis.
    """

    def __init__(self, sbox, pbox, num_rounds):
        """
        Initialize the DifferentialCryptanalysis algorithm.

        Parameters:
        - sbox: The substitution box used in the SPN
        - pbox: The permutation box used in the SPN
        - num_rounds: The number of rounds in the SPN
        """
        # 'differential' is forwarded to the base class; presumably it selects
        # the analysis mode there (the base class is not visible from this file).
        super().__init__(sbox, pbox, num_rounds, 'differential')

    def find_keybits(self, out_mask, ct_pairs, known_keyblocks=()):
        """Find the most likely last-round key bits for one output difference.

        Every value of the active key blocks is tried. A block is active when
        it is non-zero in `out_mask` and its index is not in `known_keyblocks`.
        For each candidate key both ciphertexts of every pair are passed
        through `dec_partial_last_noperm`, and the candidate is credited once
        for every pair whose resulting difference equals `out_mask` on all
        active blocks. The best `box_size` candidates are printed.

        Args:
            out_mask (int): The expected output difference of the encrypted pairs.
            ct_pairs (iterable): `(ct1, ct2)` ciphertext pairs used for the analysis.
            known_keyblocks (iterable, optional): Indices of key blocks that are
                already known and are excluded from the search. Defaults to `()`.

        Returns:
            tuple: `(key, count)` for the best candidate, where `key` is the
            integer built from a block list whose inactive blocks are 0 and
            `count` is the number of pairs that matched `out_mask`.

        Raises:
            ValueError: If `ct_pairs` is empty.
        """
        # The pairs are re-scanned for every candidate key, so a one-shot
        # iterator would be exhausted after the first candidate.
        ct_pairs = list(ct_pairs)
        if not ct_pairs:
            raise ValueError("ct_pairs must contain at least one ciphertext pair")
        known = set(known_keyblocks)
        out_blocks = self.int_to_list(out_mask)
        active_blocks = [i for i, v in enumerate(out_blocks) if v and i not in known]
        key_diffcounts = Counter()
        for klst in product(range(len(self.sbox)), repeat=len(active_blocks)):
            key = [0] * self.num_sbox
            for i, v in zip(active_blocks, klst):
                key[i] = v
            key = self.list_to_int(key)
            for ct1, ct2 in ct_pairs:
                diff = (self.dec_partial_last_noperm(ct1, [key])
                        ^ self.dec_partial_last_noperm(ct2, [key]))
                diff = self.int_to_list(diff)
                # Adding a bool also registers candidates that never match (count 0).
                key_diffcounts[key] += all(out_blocks[i] == diff[i] for i in active_blocks)
        topn = key_diffcounts.most_common(self.box_size)
        for i, v in topn:
            print(self.int_to_list(i), v)
        return topn[0]

    def find_last_roundkey(self, outmasks, cutoff=10000, multiple=1000):
        """Find the last round key from a list of differentials.

        Encryption pairs are generated for every differential, then
        `find_keybits` is run once per differential. Key blocks recovered by an
        earlier differential are passed on as known and are never overwritten.

        Args:
            outmasks (iterable): Tuples of (input mask, output mask, bias).
            cutoff (int, optional): Forwarded to `generate_encryption_pairs` as
                the cap on the pairs collected per differential. Defaults to 10000.
            multiple (int, optional): Batch size forwarded to
                `generate_encryption_pairs`. Defaults to 1000.

        Returns:
            list: One entry per S-box; blocks that are not active in any output
            mask stay `None`.
        """
        # outmasks is walked twice (pair generation, then key search); a
        # one-shot iterable would leave the second walk empty.
        outmasks = list(outmasks)
        final_key = [None] * self.num_sbox
        all_pt_ct_pairs = self.generate_encryption_pairs(outmasks, cutoff, multiple=multiple)
        for pt_ct_pairs, (_, out_mask, _) in zip(all_pt_ct_pairs, outmasks):
            ct_pairs = [pair[1] for pair in pt_ct_pairs]
            known = [i for i, v in enumerate(final_key) if v is not None]
            k = self.find_keybits(out_mask, ct_pairs, known)
            kr = self.int_to_list(k[0])
            print(kr)
            for i, v in enumerate(self.int_to_list(out_mask)):
                if v and final_key[i] is None:
                    final_key[i] = kr[i]
            print(final_key)
            print()
        return final_key

    def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000):
        """Generate plaintext/ciphertext pairs for a set of differentials.

        For every differential, pairs `(p, p ^ input_mask)` are collected in
        three passes: pairs whose two plaintexts are both already in
        `self.encryptions`, then known plaintexts paired with a new partner,
        then fresh random pairs. New plaintexts are stored with a `None`
        ciphertext and `update_encryptions` is called once at the end
        (presumably it encrypts every entry still marked `None`).

        Args:
            outmasks (iterable): Tuples of (input_diff_mask, output_diff_mask, bias).
            cutoff (int, optional): Upper bound on the number of pairs collected
                per differential. Defaults to 10000.
            multiple (int, optional): Batch size forwarded to
                `update_encryptions`. Defaults to 1000.

        Returns:
            list: For each differential, a list of `((p1, p2), (c1, c2))` entries.
        """
        num_plaintexts = 2 ** (self.num_sbox * self.box_size)
        all_pt_pairs = []
        for inp_mask, _, bias in outmasks:
            pt_pairs = []
            new_encs = {}  # new encryptions + seen encryptions
            # Only this many distinct pairs {p, p ^ inp_mask} exist in the block
            # space (assuming inp_mask fits in the block). Asking the random
            # fill below for more would loop forever.
            max_pairs = num_plaintexts if inp_mask == 0 else num_plaintexts // 2
            threshold = min(100 * int(1 / bias), cutoff, max_pairs)
            # first pass, look for already existing pairs
            for i in self.encryptions:
                if len(pt_pairs) >= threshold:
                    break
                if i in new_encs:
                    # already added the pair i.e i^inp_mask
                    continue
                if i ^ inp_mask in self.encryptions:
                    new_encs[i] = self.encryptions[i]
                    new_encs[i ^ inp_mask] = self.encryptions[i ^ inp_mask]
                    pt_pairs.append((i, i ^ inp_mask))
            # second pass, reuse known plaintexts whose partner is still missing
            for i in set(self.encryptions) - set(new_encs):
                if len(pt_pairs) >= threshold:
                    break
                new_encs[i] = self.encryptions[i]
                new_encs[i ^ inp_mask] = None  # marked to be encrypted
                pt_pairs.append((i, i ^ inp_mask))
            self.encryptions.update(new_encs)
            # third pass, top up with random pairs that are entirely new
            while len(pt_pairs) < threshold:
                r = random.randint(0, num_plaintexts - 1)
                if r in self.encryptions or r ^ inp_mask in self.encryptions:
                    continue
                self.encryptions[r] = None
                self.encryptions[r ^ inp_mask] = None
                pt_pairs.append((r, r ^ inp_mask))
            all_pt_pairs.append(pt_pairs)
        self.update_encryptions(multiple=multiple)
        all_pt_ct_pairs = []
        for pt_pairs in all_pt_pairs:
            pt_ct_pairs = []
            for (p1, p2) in pt_pairs:
                pt_ct_pairs.append(
                    ((p1, p2), (self.encryptions[p1], self.encryptions[p2])))
            all_pt_ct_pairs.append(pt_ct_pairs)
        return all_pt_ct_pairs
Classes
class DifferentialCryptanalysis (sbox, pbox, num_rounds)
-
Class for performing differential cryptanalysis.
Methods: - __init__: Initialize the differential cryptanalysis algorithm. - find_keybits: Find the key bits using differential cryptanalysis. - find_last_roundkey: Find the last round key using differential cryptanalysis. - generate_encryption_pairs: Generate encryption pairs for differential cryptanalysis.
Initialize the DifferentialCryptanalysis algorithm.
Parameters: - sbox: The substitution box used in the SPN - pbox: The permutation box used in the SPN - num_rounds: The number of rounds in the SPN
Notes: - This method is called when creating an instance of the DifferentialCryptanalysis class.
Expand source code
class DifferentialCryptanalysis(Cryptanalysis):
    """
    Class for performing differential cryptanalysis.

    Methods:
    - __init__: Initialize the differential cryptanalysis algorithm.
    - find_keybits: Find the key bits using differential cryptanalysis.
    - find_last_roundkey: Find the last round key using differential cryptanalysis.
    - generate_encryption_pairs: Generate encryption pairs for differential cryptanalysis.
    """

    def __init__(self, sbox, pbox, num_rounds):
        """
        Initialize the DifferentialCryptanalysis algorithm.

        Parameters:
        - sbox: The substitution box used in the SPN
        - pbox: The permutation box used in the SPN
        - num_rounds: The number of rounds in the SPN
        """
        # 'differential' is forwarded to the base class; presumably it selects
        # the analysis mode there (the base class is not visible from this file).
        super().__init__(sbox, pbox, num_rounds, 'differential')

    def find_keybits(self, out_mask, ct_pairs, known_keyblocks=()):
        """Find the most likely last-round key bits for one output difference.

        Every value of the active key blocks is tried. A block is active when
        it is non-zero in `out_mask` and its index is not in `known_keyblocks`.
        For each candidate key both ciphertexts of every pair are passed
        through `dec_partial_last_noperm`, and the candidate is credited once
        for every pair whose resulting difference equals `out_mask` on all
        active blocks. The best `box_size` candidates are printed.

        Args:
            out_mask (int): The expected output difference of the encrypted pairs.
            ct_pairs (iterable): `(ct1, ct2)` ciphertext pairs used for the analysis.
            known_keyblocks (iterable, optional): Indices of key blocks that are
                already known and are excluded from the search. Defaults to `()`.

        Returns:
            tuple: `(key, count)` for the best candidate, where `key` is the
            integer built from a block list whose inactive blocks are 0 and
            `count` is the number of pairs that matched `out_mask`.

        Raises:
            ValueError: If `ct_pairs` is empty.
        """
        # The pairs are re-scanned for every candidate key, so a one-shot
        # iterator would be exhausted after the first candidate.
        ct_pairs = list(ct_pairs)
        if not ct_pairs:
            raise ValueError("ct_pairs must contain at least one ciphertext pair")
        known = set(known_keyblocks)
        out_blocks = self.int_to_list(out_mask)
        active_blocks = [i for i, v in enumerate(out_blocks) if v and i not in known]
        key_diffcounts = Counter()
        for klst in product(range(len(self.sbox)), repeat=len(active_blocks)):
            key = [0] * self.num_sbox
            for i, v in zip(active_blocks, klst):
                key[i] = v
            key = self.list_to_int(key)
            for ct1, ct2 in ct_pairs:
                diff = (self.dec_partial_last_noperm(ct1, [key])
                        ^ self.dec_partial_last_noperm(ct2, [key]))
                diff = self.int_to_list(diff)
                # Adding a bool also registers candidates that never match (count 0).
                key_diffcounts[key] += all(out_blocks[i] == diff[i] for i in active_blocks)
        topn = key_diffcounts.most_common(self.box_size)
        for i, v in topn:
            print(self.int_to_list(i), v)
        return topn[0]

    def find_last_roundkey(self, outmasks, cutoff=10000, multiple=1000):
        """Find the last round key from a list of differentials.

        Encryption pairs are generated for every differential, then
        `find_keybits` is run once per differential. Key blocks recovered by an
        earlier differential are passed on as known and are never overwritten.

        Args:
            outmasks (iterable): Tuples of (input mask, output mask, bias).
            cutoff (int, optional): Forwarded to `generate_encryption_pairs` as
                the cap on the pairs collected per differential. Defaults to 10000.
            multiple (int, optional): Batch size forwarded to
                `generate_encryption_pairs`. Defaults to 1000.

        Returns:
            list: One entry per S-box; blocks that are not active in any output
            mask stay `None`.
        """
        # outmasks is walked twice (pair generation, then key search); a
        # one-shot iterable would leave the second walk empty.
        outmasks = list(outmasks)
        final_key = [None] * self.num_sbox
        all_pt_ct_pairs = self.generate_encryption_pairs(outmasks, cutoff, multiple=multiple)
        for pt_ct_pairs, (_, out_mask, _) in zip(all_pt_ct_pairs, outmasks):
            ct_pairs = [pair[1] for pair in pt_ct_pairs]
            known = [i for i, v in enumerate(final_key) if v is not None]
            k = self.find_keybits(out_mask, ct_pairs, known)
            kr = self.int_to_list(k[0])
            print(kr)
            for i, v in enumerate(self.int_to_list(out_mask)):
                if v and final_key[i] is None:
                    final_key[i] = kr[i]
            print(final_key)
            print()
        return final_key

    def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000):
        """Generate plaintext/ciphertext pairs for a set of differentials.

        For every differential, pairs `(p, p ^ input_mask)` are collected in
        three passes: pairs whose two plaintexts are both already in
        `self.encryptions`, then known plaintexts paired with a new partner,
        then fresh random pairs. New plaintexts are stored with a `None`
        ciphertext and `update_encryptions` is called once at the end
        (presumably it encrypts every entry still marked `None`).

        Args:
            outmasks (iterable): Tuples of (input_diff_mask, output_diff_mask, bias).
            cutoff (int, optional): Upper bound on the number of pairs collected
                per differential. Defaults to 10000.
            multiple (int, optional): Batch size forwarded to
                `update_encryptions`. Defaults to 1000.

        Returns:
            list: For each differential, a list of `((p1, p2), (c1, c2))` entries.
        """
        num_plaintexts = 2 ** (self.num_sbox * self.box_size)
        all_pt_pairs = []
        for inp_mask, _, bias in outmasks:
            pt_pairs = []
            new_encs = {}  # new encryptions + seen encryptions
            # Only this many distinct pairs {p, p ^ inp_mask} exist in the block
            # space (assuming inp_mask fits in the block). Asking the random
            # fill below for more would loop forever.
            max_pairs = num_plaintexts if inp_mask == 0 else num_plaintexts // 2
            threshold = min(100 * int(1 / bias), cutoff, max_pairs)
            # first pass, look for already existing pairs
            for i in self.encryptions:
                if len(pt_pairs) >= threshold:
                    break
                if i in new_encs:
                    # already added the pair i.e i^inp_mask
                    continue
                if i ^ inp_mask in self.encryptions:
                    new_encs[i] = self.encryptions[i]
                    new_encs[i ^ inp_mask] = self.encryptions[i ^ inp_mask]
                    pt_pairs.append((i, i ^ inp_mask))
            # second pass, reuse known plaintexts whose partner is still missing
            for i in set(self.encryptions) - set(new_encs):
                if len(pt_pairs) >= threshold:
                    break
                new_encs[i] = self.encryptions[i]
                new_encs[i ^ inp_mask] = None  # marked to be encrypted
                pt_pairs.append((i, i ^ inp_mask))
            self.encryptions.update(new_encs)
            # third pass, top up with random pairs that are entirely new
            while len(pt_pairs) < threshold:
                r = random.randint(0, num_plaintexts - 1)
                if r in self.encryptions or r ^ inp_mask in self.encryptions:
                    continue
                self.encryptions[r] = None
                self.encryptions[r ^ inp_mask] = None
                pt_pairs.append((r, r ^ inp_mask))
            all_pt_pairs.append(pt_pairs)
        self.update_encryptions(multiple=multiple)
        all_pt_ct_pairs = []
        for pt_pairs in all_pt_pairs:
            pt_ct_pairs = []
            for (p1, p2) in pt_pairs:
                pt_ct_pairs.append(
                    ((p1, p2), (self.encryptions[p1], self.encryptions[p2])))
            all_pt_ct_pairs.append(pt_ct_pairs)
        return all_pt_ct_pairs
Ancestors
- Cryptanalysis
- SPN
- abc.ABC
Methods
def find_keybits(self, out_mask, ct_pairs, known_keyblocks=())
-
Finds the key bits based on the output mask and ciphertext pairs.
This method overrides the abstract
find_keybits
method in the Cryptanalysis
class. It takes an output mask, a list of ciphertext pairs, and an optional list of known key blocks as input. It implements the logic to find the key bits based on the provided parameters.
Args
out_mask
:int
- The output mask for the difference in encrypted pairs.
ct_pairs
:list
- A list of ciphertext pairs used for analysis.
known_keyblocks
:list
, optional - Indices of the key blocks that are already known. Defaults to an empty tuple.
Returns
tuple
- The (key, count) pair of the best candidate: the most likely key bits and the number of ciphertext pairs that matched the output mask for them
Expand source code
def find_keybits(self, out_mask, ct_pairs, known_keyblocks=()):
    """Find the most likely last-round key bits for one output difference.

    Every value of the active key blocks is tried. A block is active when it
    is non-zero in `out_mask` and its index is not in `known_keyblocks`. For
    each candidate key both ciphertexts of every pair are passed through
    `dec_partial_last_noperm`, and the candidate is credited once for every
    pair whose resulting difference equals `out_mask` on all active blocks.
    The best `box_size` candidates are printed.

    Args:
        out_mask (int): The expected output difference of the encrypted pairs.
        ct_pairs (iterable): `(ct1, ct2)` ciphertext pairs used for the analysis.
        known_keyblocks (iterable, optional): Indices of key blocks that are
            already known and are excluded from the search. Defaults to `()`.

    Returns:
        tuple: `(key, count)` for the best candidate, where `key` is the
        integer built from a block list whose inactive blocks are 0 and
        `count` is the number of pairs that matched `out_mask`.

    Raises:
        ValueError: If `ct_pairs` is empty.
    """
    # The pairs are re-scanned for every candidate key, so a one-shot
    # iterator would be exhausted after the first candidate.
    ct_pairs = list(ct_pairs)
    if not ct_pairs:
        raise ValueError("ct_pairs must contain at least one ciphertext pair")
    known = set(known_keyblocks)
    out_blocks = self.int_to_list(out_mask)
    active_blocks = [i for i, v in enumerate(out_blocks) if v and i not in known]
    key_diffcounts = Counter()
    for klst in product(range(len(self.sbox)), repeat=len(active_blocks)):
        key = [0] * self.num_sbox
        for i, v in zip(active_blocks, klst):
            key[i] = v
        key = self.list_to_int(key)
        for ct1, ct2 in ct_pairs:
            diff = (self.dec_partial_last_noperm(ct1, [key])
                    ^ self.dec_partial_last_noperm(ct2, [key]))
            diff = self.int_to_list(diff)
            # Adding a bool also registers candidates that never match (count 0).
            key_diffcounts[key] += all(out_blocks[i] == diff[i] for i in active_blocks)
    topn = key_diffcounts.most_common(self.box_size)
    for i, v in topn:
        print(self.int_to_list(i), v)
    return topn[0]
def find_last_roundkey(self, outmasks, cutoff=10000, multiple=1000)
-
Finds the last round key based on output masks.
This method overrides the abstract
find_last_roundkey
method in the Cryptanalysis
class. It takes a list of output masks, a cutoff value, and a multiple value as input. It implements the logic to find the last round key based on the output masks and the specified parameters.
Args
outmasks
:list
- A list of tuples of (input, output masks, bias) for which the last round key needs to be found.
cutoff
:int
, optional - The cutoff value used for the maximum number of encryptions called from oracle in determining the last round key. Defaults to 10000.
multiple
:int
, optional - The multiple indicating the size of the batch of values to be encrypted at once used for generating encryption pairs. Defaults to 1000.
Returns
list
- The last round key determined based on the output masks.
Expand source code
def find_last_roundkey(self, outmasks, cutoff=10000, multiple=1000):
    """Find the last round key from a list of differentials.

    Encryption pairs are generated for every differential, then
    `find_keybits` is run once per differential. Key blocks recovered by an
    earlier differential are passed on as known and are never overwritten.

    Args:
        outmasks (iterable): Tuples of (input mask, output mask, bias).
        cutoff (int, optional): Forwarded to `generate_encryption_pairs` as
            its `cutoff` argument. Defaults to 10000.
        multiple (int, optional): Batch size forwarded to
            `generate_encryption_pairs`. Defaults to 1000.

    Returns:
        list: One entry per S-box; blocks that are not active in any output
        mask stay `None`.
    """
    # outmasks is walked twice (pair generation, then key search); a
    # one-shot iterable would leave the second walk empty.
    outmasks = list(outmasks)
    final_key = [None] * self.num_sbox
    all_pt_ct_pairs = self.generate_encryption_pairs(outmasks, cutoff, multiple=multiple)
    for pt_ct_pairs, (_, out_mask, _) in zip(all_pt_ct_pairs, outmasks):
        ct_pairs = [pair[1] for pair in pt_ct_pairs]
        known = [i for i, v in enumerate(final_key) if v is not None]
        k = self.find_keybits(out_mask, ct_pairs, known)
        kr = self.int_to_list(k[0])
        print(kr)
        for i, v in enumerate(self.int_to_list(out_mask)):
            if v and final_key[i] is None:
                final_key[i] = kr[i]
        print(final_key)
        print()
    return final_key
def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000)
-
Generates encryption pairs for a set of output masks.
This method overrides the abstract
generate_encryption_pairs
method in the Cryptanalysis
class. It takes a list of output masks, a cutoff value, and a multiple value as input. It generates plaintext-ciphertext pairs for each output mask based on the specified parameters.
Args
outmasks
:list
- A list of tuples of (input_diff_mask, output_diff_mask and bias) for which encryption pairs need to be generated.
cutoff
:int
, optional - The cutoff value of the max number of encryptions invoked. Defaults to 10000.
multiple
:int
, optional - The multiple indicating the size of the batch of values to be encrypted at once used for generating encryption pairs. Defaults to 1000.
Returns
list
- A list of plaintext-ciphertext pairs for each output mask.
Expand source code
def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000):
    """Generate plaintext/ciphertext pairs for a set of differentials.

    For every differential, pairs `(p, p ^ input_mask)` are collected in
    three passes: pairs whose two plaintexts are both already in
    `self.encryptions`, then known plaintexts paired with a new partner,
    then fresh random pairs. New plaintexts are stored with a `None`
    ciphertext and `update_encryptions` is called once at the end
    (presumably it encrypts every entry still marked `None`).

    Args:
        outmasks (iterable): Tuples of (input_diff_mask, output_diff_mask, bias).
        cutoff (int, optional): Upper bound on the number of pairs collected
            per differential. Defaults to 10000.
        multiple (int, optional): Batch size forwarded to
            `update_encryptions`. Defaults to 1000.

    Returns:
        list: For each differential, a list of `((p1, p2), (c1, c2))` entries.
    """
    num_plaintexts = 2 ** (self.num_sbox * self.box_size)
    all_pt_pairs = []
    for inp_mask, _, bias in outmasks:
        pt_pairs = []
        new_encs = {}  # new encryptions + seen encryptions
        # Only this many distinct pairs {p, p ^ inp_mask} exist in the block
        # space (assuming inp_mask fits in the block). Asking the random fill
        # below for more would loop forever.
        max_pairs = num_plaintexts if inp_mask == 0 else num_plaintexts // 2
        threshold = min(100 * int(1 / bias), cutoff, max_pairs)
        # first pass, look for already existing pairs
        for i in self.encryptions:
            if len(pt_pairs) >= threshold:
                break
            if i in new_encs:
                # already added the pair i.e i^inp_mask
                continue
            if i ^ inp_mask in self.encryptions:
                new_encs[i] = self.encryptions[i]
                new_encs[i ^ inp_mask] = self.encryptions[i ^ inp_mask]
                pt_pairs.append((i, i ^ inp_mask))
        # second pass, reuse known plaintexts whose partner is still missing
        for i in set(self.encryptions) - set(new_encs):
            if len(pt_pairs) >= threshold:
                break
            new_encs[i] = self.encryptions[i]
            new_encs[i ^ inp_mask] = None  # marked to be encrypted
            pt_pairs.append((i, i ^ inp_mask))
        self.encryptions.update(new_encs)
        # third pass, top up with random pairs that are entirely new
        while len(pt_pairs) < threshold:
            r = random.randint(0, num_plaintexts - 1)
            if r in self.encryptions or r ^ inp_mask in self.encryptions:
                continue
            self.encryptions[r] = None
            self.encryptions[r ^ inp_mask] = None
            pt_pairs.append((r, r ^ inp_mask))
        all_pt_pairs.append(pt_pairs)
    self.update_encryptions(multiple=multiple)
    all_pt_ct_pairs = []
    for pt_pairs in all_pt_pairs:
        pt_ct_pairs = []
        for (p1, p2) in pt_pairs:
            pt_ct_pairs.append(
                ((p1, p2), (self.encryptions[p1], self.encryptions[p2])))
        all_pt_ct_pairs.append(pt_ct_pairs)
    return all_pt_ct_pairs
Inherited members