Package cryptanalysis

Expand source code
from .spn import SPN, rotate_left, gen_pbox
from .characteristic_searcher import CharacteristicSearcher
from .linear_cryptanalysis import LinearCryptanalysis
from .differential_cryptanalysis import DifferentialCryptanalysis
from .utils import parity, calculate_linear_bias, calculate_difference_table
# Public API of the package: the classes and helper functions re-exported
# from the submodules imported above (what ``from cryptanalysis import *``
# brings into scope).
__all__ = ["CharacteristicSearcher", "LinearCryptanalysis",
           "DifferentialCryptanalysis", "SPN", "rotate_left", "gen_pbox",
           "parity", "calculate_linear_bias", "calculate_difference_table"]

Sub-modules

cryptanalysis.characteristic_searcher

characteristic_searcher …

cryptanalysis.cryptanalysis

Module for performing cryptanalysis on Substitution Permutation Network ciphers …

cryptanalysis.differential_cryptanalysis

Module for performing differential cryptanalysis on Substitution Permutation Network based ciphers …

cryptanalysis.linear_cryptanalysis

Module for performing linear cryptanalysis on Substitution Permutation Network based ciphers …

cryptanalysis.spn

The cryptanalysis.spn module implements the Substitution-Permutation Network (SPN) encryption algorithm …

cryptanalysis.utils

The cryptanalysis.utils module provides utility functions for various operations …

Functions

def calculate_difference_table(sbox)

Calculates the difference distribution table for an S-box.

This function calculates the difference distribution table for an S-box. It iterates over every pair of input value and input difference and counts how often each output difference occurs for each input difference.

Args

sbox : list
A list of integers representing the S-box.

Returns

Counter
A Counter keyed by (input_difference, output_difference) pairs, giving the number of inputs for which that input difference produces that output difference.
Expand source code
def calculate_difference_table(sbox):
    """Build the difference distribution table (DDT) of an S-box.

    For every input difference ``dx`` and every input ``x`` the output
    difference ``sbox[x] ^ sbox[x ^ dx]`` is tallied. The entry
    ``table[(dx, dy)]`` is therefore the number of inputs ``x`` for which the
    input difference ``dx`` produces the output difference ``dy``.

    Args:
        sbox (list): A list of integers representing the S-box.

    Returns:
        Counter: Maps ``(input_difference, output_difference)`` pairs to the
        number of inputs exhibiting that transition.
    """
    size = len(sbox)
    table = Counter()
    for dx in tqdm(range(size), desc='calculating sbox differences'):
        # Counter.update with an iterable adds 1 per yielded key.
        table.update((dx, sbox[x] ^ sbox[x ^ dx]) for x in range(size))
    return table
def calculate_linear_bias(sbox, no_sign=True, fraction=False)

Calculates the linear bias of an S-box.

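This function calculates the linear bias of an S-box. It iterates over all possible input and output mask pairs and computes the linear bias using the module-level parity function. Each entry is the number of inputs for which the masked input bits and masked output bits have equal parity, minus half the S-box size.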

Args

sbox : list
A list of integers representing the S-box.
no_sign : bool, optional
If True, the absolute value of the bias is returned. Defaults to True.
fraction : bool, optional
If True, each bias is divided by the S-box size (len(sbox)), giving a float. Defaults to False.

Returns

Counter
A Counter dictionary containing the linear biases for each input and output mask pair.
Expand source code
def calculate_linear_bias(sbox, no_sign=True, fraction=False):
    """Compute the linear approximation (bias) table of an S-box.

    For each input mask ``a`` and output mask ``b``, the entry is computed as
    follows. Count the inputs ``x`` for which
    ``parity((sbox[x] & b) ^ (x & a)) == 0``, i.e. the masked input bits and
    the masked output bits have equal parity. Then subtract
    ``len(sbox) // 2``. Parities come from the module-level ``parity``
    helper.

    Args:
        sbox (list): A list of integers representing the S-box.
        no_sign (bool, optional): If True, absolute values are stored.
            Defaults to True.
        fraction (bool, optional): If True, every entry is divided by
            ``len(sbox)`` (true division, yielding floats). Defaults to False.

    Returns:
        Counter: Maps ``(input_mask, output_mask)`` pairs to their bias.
    """
    size = len(sbox)
    half = size // 2
    table = Counter({(a, b): -half for a in range(size) for b in range(size)})
    for a in tqdm(range(size), desc='calculating sbox bias'):
        for b in range(size):
            # number of inputs where the approximation a.x == b.S(x) holds
            holds = sum(parity((sbox[x] & b) ^ (x & a)) ^ 1 for x in range(size))
            table[(a, b)] += holds
    if no_sign:
        for key, value in table.items():
            table[key] = abs(value)
    if fraction:
        for key, value in table.items():
            table[key] = value / size
    return table
def gen_pbox(s, n)

Generate a balanced permutation box for an SPN.

Parameters

s : int
Number of bits per S-box.
n : int
Number of S-boxes.

Returns

list of int
The generated P-box.
Expand source code
def gen_pbox(s, n):
    """
    Generate a balanced permutation box for an SPN.

    The returned list is a permutation of ``range(s * n)``. The entry at
    index ``j * n + i`` is ``s * i + j``, for ``0 <= i < n`` and
    ``0 <= j < s``. In other words, bit ``j`` of every S-box is grouped
    together, S-box by S-box.

    Parameters
    ----------
    s : int
        Number of bits per S-box.
    n : int
        Number of S-boxes.

    Returns
    -------
    list of int
        The generated P-box.

    """
    pbox = []
    for bit in range(s):
        # s * box + bit is always below s * n, so no wrap-around is needed
        pbox.extend(s * box + bit for box in range(n))
    return pbox
def parity(x)

Calculates the parity of an integer.

This method calculates the parity of an integer by counting the number of set bits in the binary representation of the integer. It returns 0 if the number of set bits is even, and 1 otherwise.

Args

x : int
The input value for which the parity is calculated.

Returns

int
0 if the number of set bits is even, 1 otherwise.
Expand source code
def parity(x):
    """Calculates the parity of a non-negative integer.

    Counts the set bits in the binary representation of ``x``. Returns 0 if
    that count is even and 1 otherwise.

    Args:
        x (int): The non-negative input value for which the parity is calculated.

    Returns:
        int: 0 if the number of set bits is even, 1 otherwise.

    Raises:
        ValueError: If ``x`` is negative. In two's complement a negative
            Python int has infinitely many set bits, so its parity is
            undefined. A bit-clearing loop (``x &= x - 1``) would never
            terminate on such input.
    """
    if x < 0:
        raise ValueError("parity is undefined for negative integers: {}".format(x))
    # bin() yields e.g. '0b1011'; the '0b' prefix contains no '1'.
    return bin(x).count("1") & 1
def rotate_left(val, shift, mod)

Rotate the bits of the value to the left by the shift amount.

Parameters

val : int
The value to be rotated.
shift : int
The number of places to shift the value to the left.
mod : int
The bit width of the word being rotated; the result is reduced to this many bits.

Returns

int
The rotated value.

Notes

The function rotates the bits of the value to the left by the shift amount, wrapping the bits that overflow. The result is then masked by (1<<mod)-1 to only keep the mod number of least significant bits.

Expand source code
def rotate_left(val, shift, mod):
    """
    Rotate the bits of the value to the left by the shift amount.

    Parameters
    ----------
    val : int
        The value to be rotated. Only its ``mod`` least significant bits
        take part in the rotation.
    shift : int
        The number of places to shift the value to the left. It is reduced
        modulo ``mod``, so a negative shift rotates to the right.
    mod : int
        The bit width of the rotated word.

    Returns
    -------
    int
        The rotated value, in the range ``[0, 2**mod)``.

    Notes
    -----
    The function rotates the ``mod``-bit word held in the low bits of ``val``
    to the left by the shift amount. Bits that overflow wrap back in at the
    bottom, and the result is masked by (1<<mod)-1. ``val`` is also masked to
    ``mod`` bits *before* rotating. Otherwise the wrap-around term
    ``val >> (mod - shift)`` would shift bits above position ``mod - 1``
    down into the result.

    """
    mask = (1 << mod) - 1
    shift = shift % mod
    val &= mask
    return ((val << shift) | (val >> (mod - shift))) & mask

Classes

class CharacteristicSearcher (sbox, pbox, num_rounds, mode='linear')

A class for finding characteristics (linear or differential) of a substitution permutation network with provided S-box and P-box with a given number of rounds.

Attributes

sbox
A list representing the substitution box.
pbox
A list representing the permutation box.
num_rounds
An integer representing the number of rounds.
block_size
An integer representing the number of bits in the block.
box_size
An integer representing the size of the S-box in bits.
num_blocks
An integer representing the number of sboxes in a block
mode
A string representing the mode, which can be 'linear' or 'differential'.
bias
A Counter dictionary representing linear or differential bias of sbox input/output pairs
solutions
A dictionary containing lists of valid characteristic masks for a given set of included and excluded blocks
solver
SMT solver (optimize) instance to search the characteristics

Initializes the CharacteristicSearcher with the given sbox, pbox, num_rounds and mode.

Args

sbox : list
The substitution box.
pbox : list
The permutation box.
num_rounds : int
The number of rounds.
mode : str, optional
The mode of operation. Defaults to 'linear'.
Expand source code
class CharacteristicSearcher:
    """A class for finding characteristics (linear or differential) of a substitution
    permutation network with provided S-box and P-box with a given number of rounds.

    Attributes:
        sbox: A list representing the substitution box.
        pbox: A list representing the permutation box.
        num_rounds: An integer representing the number of rounds.
        block_size: An integer representing the number of bits in the block.
        box_size: An integer representing the size of the S-box in bits.
        num_blocks: An integer representing the number of S-boxes in a block.
        mode: A string representing the mode, which can be 'linear' or 'differential'.
        bias: A Counter dictionary representing linear or differential bias
              of sbox input/output pairs.
        solutions: A dictionary of lists of found characteristics. The key
            depends on the caller: solve_for_blocks uses the (included blocks,
            excluded blocks) tuples, and search_best_masks uses the total bias.
        solver: SMT solver (optimize) instance to search the characteristics.
        prune_level: The bias pruning level in use (see add_bias_constraints).
    """

    def __init__(self, sbox, pbox, num_rounds, mode='linear'):
        """Initializes the CharacteristicSearcher with the given sbox, pbox, num_rounds and mode.

        Args:
            sbox (list): The substitution box.
            pbox (list): The permutation box.
            num_rounds (int): The number of rounds.
            mode (str, optional): Either 'linear' or 'differential'. Defaults to 'linear'.

        Raises:
            ValueError: If ``mode`` is neither 'linear' nor 'differential'.
        """
        # Validate early: any other mode would leave ``self.bias`` unset and
        # only fail much later (AttributeError in add_bias_constraints or a
        # KeyError on ``self.objectives[self.mode]``).
        if mode not in ('linear', 'differential'):
            raise ValueError(
                "mode must be 'linear' or 'differential', got {!r}".format(mode))
        self.sbox = sbox
        self.pbox = pbox
        self.num_rounds = num_rounds
        self.block_size = len(pbox)
        self.box_size = int(log2(len(sbox)))
        self.num_blocks = len(pbox) // self.box_size
        self.mode = mode
        if mode == 'linear':
            self.bias = calculate_linear_bias(sbox)
        else:
            self.bias = calculate_difference_table(sbox)
        self.solutions = defaultdict(list)
        self.solver = Optimize()
        self.prune_level = 0
        # The attributes below are populated by init_characteristic_solver().
        self.sboxf = None
        self.inps = None
        self.oups = None
        self.bv_inp_masks = None
        self.bv_oup_masks = None
        self.objectives = None

    def initialize_sbox_structure(self):
        """Initializes the S-box structure for the cryptographic solver.

        This method creates the per-round input and output bit vectors of
        every S-box and adds the structural constraints to the solver. It
        also creates a concatenated view of the input and output layers for
        further processing.
        """
        n = self.box_size
        self.inps = [[BitVec('r{}_i{}'.format(r, i), n) for i in range(
            self.num_blocks)] for r in range(self.num_rounds + 1)]
        self.oups = [[BitVec('r{}_o{}'.format(r, i), n) for i in range(
            self.num_blocks)] for r in range(self.num_rounds)]
        # permutation of output of sboxes are inputs of next round
        for i in range(self.num_rounds):
            if self.num_blocks == 1:
                self.solver.add(self.bitvec_permutation(
                    self.oups[i][0], self.inps[i + 1][0]))
            else:
                self.solver.add(self.bitvec_permutation(
                    Concat(self.oups[i]), Concat(self.inps[i + 1])))
        # all first layer inputs should not be 0
        self.solver.add(
            Not(And(*[self.inps[0][i] == 0 for i in range(self.num_blocks)])))
        for r in range(self.num_rounds):
            for i in range(self.num_blocks):
                # if sbox has input, it should have output
                self.solver.add(
                    Implies(
                        self.inps[r][i] != 0,
                        self.oups[r][i] != 0))
                # if sbox has no input it should not have any output
                self.solver.add(
                    Implies(
                        self.inps[r][i] == 0,
                        self.oups[r][i] == 0))

        # just a concatenated view of the input and output layers
        if self.num_blocks == 1:
            self.bv_inp_masks = [self.inps[i][0]
                                 for i in range(self.num_rounds + 1)]
            self.bv_oup_masks = [self.oups[i][0]
                                 for i in range(self.num_rounds)]
        else:
            self.bv_inp_masks = [Concat(self.inps[i])
                                 for i in range(self.num_rounds + 1)]
            self.bv_oup_masks = [Concat(self.oups[i])
                                 for i in range(self.num_rounds)]

    def bitvec_permutation(self, inp, oup):
        """Performs bit vector permutation based on pbox.

        Bits are counted from the most significant end. Bit ``i`` of ``inp``
        must equal bit ``self.pbox[i]`` of ``oup``.

        Args:
            inp (BitVec): The input bit vector.
            oup (BitVec): The output bit vector.

        Returns:
            list: A list of constraints for the permutation.
        """
        pn = len(self.pbox)
        constraints = []
        for i, v in enumerate(self.pbox):
            constraints.append(
                Extract(pn - 1 - i, pn - 1 - i, inp) ==
                Extract(pn - 1 - v, pn - 1 - v, oup)
            )
        return constraints

    def initialize_objectives(self):
        """Initializes the objective functions for the cryptographic solver.

        The method sets up four types of objective functions: 'original_linear',
        'reduced', 'differential', and 'linear'. These objective functions are
        used to guide the solver in finding the optimal solution. Each objective
        function is associated with a lambda function that calculates the objective
        value for a given number of rounds.
        'reduced' objective is called for both linear and differential search
        Other objective functions are just there for reference and easy evaluation
        of bias directly from the model
        """
        self.objectives = {
            # the actual objective, which is just product of bias [0,1/2]
            'original_linear': lambda rounds: 2**(self.num_blocks * rounds - 1) * Product([self.sboxf(
                self.inps[i // self.num_blocks][i % self.num_blocks],
                self.oups[i // self.num_blocks][i % self.num_blocks])
                for i in range(self.num_blocks * rounds)
            ]),
            # reducing optimization problem of product to sums
            'reduced': lambda rounds: sum([
                self.sboxf(
                    self.inps[i // self.num_blocks][i % self.num_blocks],
                    self.oups[i // self.num_blocks][i % self.num_blocks])
                for i in range(self.num_blocks * rounds)
            ]),
            # objective when the input biases are [0,2**n] just the final
            # division
            'differential': lambda rounds: Product([
                self.sboxf(
                    self.inps[i // self.num_blocks][i % self.num_blocks],
                    self.oups[i // self.num_blocks][i % self.num_blocks])
                for i in range(self.num_blocks * rounds)
            ]) / ((2**self.box_size)**(self.num_blocks * rounds)),
            'linear': lambda rounds: 2**(self.num_blocks * rounds - 1) * Product([
                self.sboxf(
                    self.inps[i // self.num_blocks][i % self.num_blocks],
                    self.oups[i // self.num_blocks][i % self.num_blocks])
                for i in range(self.num_blocks * rounds)
            ]) / ((2**self.box_size)**(self.num_blocks * rounds))
        }

    def add_bias_constraints(self, prune_level):
        """Adds bias constraints to the solver based on the biases of the S-box.

        This method adds constraints to the solver that are based on the biases of the S-box.
        If the bias of a particular input-output pair is greater than or equal to 2**prune_level,
        the method adds a constraint that the S-box function of the pair is equal to the bias.
        Otherwise, it adds a constraint that the S-box function of the pair is 0. This helps in
        pruning the search space of the solver.

        Args:
            prune_level (int): The level at which to prune the biases.
        """
        for i in range(2**self.box_size):
            for j in range(2**self.box_size):
                # just some pruning of very small biases
                if self.bias[(i, j)] >= 2**(prune_level):
                    self.solver.add(self.sboxf(i, j) == self.bias[(i, j)])
                else:
                    self.solver.add(self.sboxf(i, j) == 0)

        for r in range(self.num_rounds):
            for i in range(self.num_blocks):
                # skip taking input/outputs with no bias
                self.solver.add(
                    Implies(
                        And(self.inps[r][i] != 0, self.oups[r][i] != 0),
                        self.sboxf(self.inps[r][i], self.oups[r][i]) != 0
                    )
                )

    def _search_prune_level(self):
        """Binary-search the largest satisfiable pruning level.

        Each candidate level is tried inside a push/pop scope with a
        10000 ms solver timeout. Any result other than ``sat`` (including a
        timeout) counts as failure. Afterwards the timeout is reset to 0.

        Returns:
            int: The best pruning level found, or -1 if no level in the
            search range was satisfiable.
        """
        print("searching best pruning level")
        low, high = 0, len(self.sbox) // 4
        while low <= high:
            mid = (low + high) // 2
            print("trying pruning", mid)
            self.solver.push()
            self.solver.set(timeout=10000)
            self.add_bias_constraints(mid)
            if self.solver.check() == sat:
                print("success")
                low = mid + 1
            else:
                print("failure")
                high = mid - 1
            self.solver.pop()
        self.solver.set(timeout=0)
        print("best pruning", high)
        return high

    def init_characteristic_solver(self, prune_level=-1):
        """Initializes the S-box structure, S-box function, objective functions, and pruning level.

        This method initializes the structure of the S-box, the S-box function,
        and the objective functions for the solver. It also sets the pruning
        level. If no level is provided, or the provided level makes the model
        unsatisfiable, the best pruning level is searched for instead.

        Args:
            prune_level (int, optional): The level at which to prune the biases.
                If not provided or less than 0, the method will search for the best pruning level.

        Raises:
            RuntimeError: If the solver is not satisfiable before any bias
                constraints are added.
        """
        self.initialize_sbox_structure()
        self.sboxf = Function(
            'sbox', BitVecSort(
                self.box_size), BitVecSort(
                self.box_size), RealSort())
        self.initialize_objectives()
        # check() returns a CheckSatResult object, which is always truthy, so
        # a bare ``assert self.solver.check()`` can never fire. Compare
        # explicitly, and raise rather than assert: asserts vanish under -O.
        if self.solver.check() != sat:
            raise RuntimeError(
                "solver is not satisfiable before adding S-box bias constraints")

        if prune_level >= 0:
            # Try the requested level in a scratch scope. A failing level must
            # not leave unsatisfiable constraints behind, otherwise the
            # fallback search below could never succeed.
            self.solver.push()
            self.add_bias_constraints(prune_level)
            result = self.solver.check()
            self.solver.pop()
            if result == sat:
                self.prune_level = prune_level
                self.add_bias_constraints(prune_level)
                return
            print("Provided pruning level unsat, searching optimal pruning")
        self.prune_level = self._search_prune_level()
        self.add_bias_constraints(self.prune_level)

    def solve_for_blocks(self, include_blocks=(), exclude_blocks=(),
                         num_rounds=0,
                         num_sols=1,
                         display_paths=True):
        """Solves the characteristic for the specified blocks and maximizes the objective function.

        This method searches the characteristic for the specified blocks,
        maximizes the objective function, and returns the solutions.
        The blocks to include and exclude in the characteristic can be specified.
        The number of rounds and the number of solutions can also be specified.

        Args:
            include_blocks (list, optional): The blocks to definitely include in the characteristic.
            exclude_blocks (list, optional): The blocks to definitely exclude in the characteristic.
            num_rounds (int, optional): The number of rounds for which to solve the characteristic.
                If not provided or 0, the number of rounds will be set to the
                number of rounds of the solver.
            num_sols (int, optional): The number of solutions to return.
            display_paths (bool, optional): Whether to display the paths of the solutions.

        Returns:
            list: A list of tuples. Each tuple contains the first-round input mask,
                  the last-round input mask, and the calculated bias for a solution.
        """
        if num_rounds == 0:
            num_rounds = self.num_rounds
        else:
            # cap to initialized struct
            num_rounds = min(self.num_rounds, num_rounds)
        while len(self.solver.objectives()):
            self.solver.pop()  # remove any previous include/exclude block constraints
        self.solver.push()  # set this as the checkpoint
        # specify which blocks to definitely include in the characteristic
        for i in include_blocks:
            self.solver.add(self.inps[num_rounds - 1][i] != 0)
        # specify which blocks to definitely exclude in the characteristic
        for i in exclude_blocks:
            self.solver.add(self.inps[num_rounds - 1][i] == 0)
        # if a block is neither in include_blocks or exclude_blocks
        # the solver finds the best path which may or may not set it to active
        self.solver.maximize(self.objectives['reduced'](num_rounds))
        solutions = self.get_masks(num_rounds, num_sols, display_paths)
        self.solutions[(tuple(sorted(include_blocks)),
                        tuple(sorted(exclude_blocks)))].extend(solutions)
        return [(inp_masks[0], inp_masks[-1], calc_bias)
                for inp_masks, _, calc_bias, _ in solutions]

    def search_best_masks(self, tolerance=1, choose_best=10, display_paths=True):
        """Searches for the best masks with the highest total bias and limited undiscovered active blocks.

        This method searches for the best masks with the highest total bias and a limited number
        of undiscovered active blocks. The search stops when the solver becomes
        unsatisfiable or once every block has been discovered.

        Args:
            tolerance (int, optional): The maximum number of undiscovered active blocks allowed.
            choose_best (int, optional): The number of best masks to choose from.
            display_paths (bool, optional): Whether to display the characteristic paths
                                        (containing the bits involved) of the solutions.

        Returns:
            list: A list of tuples. Each tuple contains the first-round input mask,
                  the last-round input mask, and the total bias for a solution.
        """

        self.init_characteristic_solver()
        nr = self.num_rounds
        discovered = [False for _ in range(self.num_blocks)]

        def istolerable(x):
            return sum((not i) and j
                       for i, j in zip(discovered, x[3])) in range(1, tolerance + 1)
        masks = []
        while self.solver.check() == sat:
            curr_masks = self.get_masks(self.num_rounds, choose_best, display_paths=False)
            for i in curr_masks:
                self.solutions[i[2]].append(i)
            curr_masks = list(filter(istolerable, curr_masks))
            if len(curr_masks) > 0:
                inp_masks, oup_masks, total_bias, active = max(
                    curr_masks, key=lambda x: (x[2], -sum(x[3])))
                if display_paths:
                    self.print_bitrelations(inp_masks, oup_masks)
                    print("total bias:", total_bias)
                    print()
                masks.append((inp_masks[0], inp_masks[nr - 1], total_bias))
                for i, v in enumerate(discovered):
                    if (not v) and active[i]:
                        discovered[i] = True
                print("discovered", "".join(map(lambda x: str(int(x)), discovered)))
                if all(discovered):
                    # Every block is covered, so no later candidate can pass
                    # istolerable(). Also, the constraint below would be built
                    # from And() over an empty list, which z3's And() rejects
                    # (it needs at least one Z3 argument).
                    break
                # dont discover biases where all the active blocks come from discovered blocks
                # i.e. if all the active blocks come from discovered blocks,
                # it means, all the undiscovered blocks are inactive
                # i.e it should not be the case where all the undiscovered blocks are
                # inactive i.e 0
                self.solver.add(Not(And(
                    [self.inps[nr - 1][i] == 0 for i, v in enumerate(discovered) if not v]
                )))
        return masks

    def search_exclusive_masks(self, prune_level=-1, repeat=1):
        """Searches for the masks for each block by including only one block and excluding all the others.

        This method searches for the masks for each block by including only one block and excluding
        all the others.

        Args:
            prune_level (int, optional): The level at which to prune the biases.
            repeat (int, optional): The number of times to repeat the search for each block.

        Returns:
            list: A list of tuples. Each tuple contains the first-round input mask,
                  the last-round input mask, and the total bias for a solution.
        """
        self.init_characteristic_solver(prune_level)
        masks = []
        for i in range(self.num_blocks):
            include_blocks = {i}
            exclude_blocks = set(range(self.num_blocks)) - include_blocks
            masks.extend(self.solve_for_blocks(include_blocks, exclude_blocks, num_sols=repeat))
        return masks

    def get_masks(self, num_rounds, n=1, display_paths=True):
        """Returns the input masks, output masks, total bias, and active blocks of the solutions.

        Up to ``n`` models are enumerated with ``all_smt``, distinguished by
        their last-round input mask.

        Args:
            num_rounds (int): The number of rounds for which to get the masks.
            n (int, optional): The number of masks to get.
            display_paths (bool, optional): Whether to display the paths of the solutions.

        Returns:
            list: A list of tuples. Each tuple contains the input masks, the output masks, the total bias,
                  and the active blocks for a solution.
        """
        masks = []
        for m in islice(all_smt(self.solver, [self.bv_inp_masks[num_rounds - 1]]), n):
            inp_masks = [m.eval(i).as_long()
                         for i in self.bv_inp_masks[:num_rounds]]
            oup_masks = [m.eval(i).as_long()
                         for i in self.bv_oup_masks[:num_rounds]]
            total_bias = m.eval(
                self.objectives[self.mode](num_rounds)).as_fraction()
            active = [m.eval(i).as_long() != 0 for i in self.inps[num_rounds - 1]]
            if display_paths:
                self.print_bitrelations(inp_masks, oup_masks)
                print("total bias:", total_bias)
                print()
            masks.append((inp_masks, oup_masks, total_bias, active))
        return masks

    def print_bitrelations(self, inp_masks, out_masks):
        """
        Print the input and output masks of a block cipher in a formatted manner.

        :param inp_masks: List of integers, input masks for each round
            (one more entry than ``out_masks``; the last one is printed at the end).
        :param out_masks: List of integers, output masks for each round.
        """
        s = self.box_size
        n = self.num_blocks * s

        def bin_sep(val):
            v = bin(val)[2:].zfill(n)
            return "|".join(v[i:i + s] for i in range(0, n, s))

        rounds = len(out_masks)
        for i in range(rounds):
            imask, omask = inp_masks[i], out_masks[i]
            print(bin_sep(imask))
            print(' '.join(['-' * s] * (n // s)))
            print(bin_sep(omask))
            print()
        print(bin_sep(inp_masks[-1]))

Methods

def add_bias_constraints(self, prune_level)

Adds bias constraints to the solver based on the biases of the S-box.

This method adds constraints to the solver that are based on the biases of the S-box. If the bias of a particular input-output pair is greater than or equal to 2**prune_level, the method adds a constraint that the S-box function of the pair is equal to the bias. Otherwise, it adds a constraint that the S-box function of the pair is 0. This helps in pruning the search space of the solver.

Args

prune_level : int
The level at which to prune the biases.
Expand source code
def add_bias_constraints(self, prune_level):
    """Pin the uninterpreted S-box function to the (pruned) bias table.

    Every S-box input/output pair whose bias is at least ``2**prune_level``
    is constrained so that ``sboxf`` returns that bias. Every weaker pair is
    forced to 0, which prunes it from the search space. In addition, for
    every round and every S-box, an active input together with an active
    output must map to a non-zero ``sboxf`` value. As a result, pruned
    transitions cannot be selected.

    Args:
        prune_level (int): The level at which to prune the biases.
    """
    threshold = 2**(prune_level)
    width = 2**self.box_size
    for a in range(width):
        for b in range(width):
            value = self.bias[(a, b)]
            # very small biases are pruned by pinning them to zero
            self.solver.add(self.sboxf(a, b) == (value if value >= threshold else 0))

    for r in range(self.num_rounds):
        for i in range(self.num_blocks):
            inp = self.inps[r][i]
            oup = self.oups[r][i]
            both_active = And(inp != 0, oup != 0)
            # an active S-box may only use a transition with non-zero bias
            self.solver.add(Implies(both_active, self.sboxf(inp, oup) != 0))
def bitvec_permutation(self, inp, oup)

Performs bit vector permutation based on pbox.

Args

inp : BitVec
The input bit vector.
oup : BitVec
The output bit vector.

Returns

list
A list of constraints for the permutation.
Expand source code
def bitvec_permutation(self, inp, oup):
    """Build the bit-wiring constraints implied by the P-box.

    Bits are counted from the most significant end. For every position
    ``src`` of ``self.pbox``, bit ``src`` of ``inp`` must equal bit
    ``self.pbox[src]`` of ``oup``.

    Args:
        inp (BitVec): The input bit vector.
        oup (BitVec): The output bit vector.

    Returns:
        list: One equality constraint per P-box entry, in P-box order.
    """
    top = len(self.pbox) - 1
    return [
        Extract(top - src, top - src, inp) == Extract(top - dst, top - dst, oup)
        for src, dst in enumerate(self.pbox)
    ]
def get_masks(self, num_rounds, n=1, display_paths=True)

Returns the input masks, output masks, total bias, and active blocks of the solutions.

This method returns the input masks, output masks, total bias, and active blocks of the solutions.

Args

num_rounds : int
The number of rounds for which to get the masks.
n : int, optional
The number of masks to get.
display_paths : bool, optional
Whether to display the paths of the solutions.

Returns

list
A list of tuples. Each tuple contains the input masks, the output masks, the total bias, and the active blocks for a solution.
Expand source code
def get_masks(self, num_rounds, n=1, display_paths=True):
    """Collect masks, bias and active blocks from up to ``n`` solver models.

    Models are enumerated through ``all_smt``, distinguished by the input
    mask of layer ``num_rounds - 1``. Each model is evaluated into plain
    integers and, optionally, printed.

    Args:
        num_rounds (int): The number of rounds for which to get the masks.
        n (int, optional): The number of masks to get.
        display_paths (bool, optional): Whether to display the paths of the solutions.

    Returns:
        list: Tuples ``(input_masks, output_masks, total_bias, active)``, where
        ``active`` flags which S-boxes of layer ``num_rounds - 1`` have a
        non-zero input.
    """
    found = []
    models = islice(all_smt(self.solver, [self.bv_inp_masks[num_rounds - 1]]), n)
    for model in models:
        input_layers = [model.eval(layer).as_long() for layer in self.bv_inp_masks[:num_rounds]]
        output_layers = [model.eval(layer).as_long() for layer in self.bv_oup_masks[:num_rounds]]
        objective = self.objectives[self.mode](num_rounds)
        bias_value = model.eval(objective).as_fraction()
        active_blocks = [model.eval(box).as_long() != 0 for box in self.inps[num_rounds - 1]]
        if display_paths:
            self.print_bitrelations(input_layers, output_layers)
            print("total bias:", bias_value)
            print()
        found.append((input_layers, output_layers, bias_value, active_blocks))
    return found
def init_characteristic_solver(self, prune_level=-1)

Initializes the S-box structure, S-box function, objective functions, and pruning level.

This method initializes the structure of the S-box, the S-box function, and the objective functions for the solver. It also sets the pruning level for the solver. If no pruning level is provided, the method will search for the best pruning level.

Args

prune_level : int, optional
The level at which to prune the biases.

If not provided or less than 0, the method will search for the best pruning level.

Expand source code
def init_characteristic_solver(self, prune_level=-1):
    """Initializes the S-box structure, S-box function, objective functions, and pruning level.

    This method initializes the structure of the S-box, the uninterpreted
    ``sbox`` function, and the objective functions for the solver. It then
    fixes the pruning level.

    - A non-negative ``prune_level`` is tried first, inside its own solver scope.
    - If that level is satisfiable, its bias constraints are kept.
    - If it is negative, or turns out to be unsatisfiable, the best level is
      found by a binary search between 0 and ``len(self.sbox) // 4``. Each
      probe runs with a 10 second timeout.

    Args:
        prune_level (int, optional): The level at which to prune the biases.
        If not provided or less than 0, the method will search for the best pruning level.

    Raises:
        RuntimeError: If the S-box structure constraints alone are not satisfiable.
    """
    self.initialize_sbox_structure()
    self.sboxf = Function(
        'sbox', BitVecSort(
            self.box_size), BitVecSort(
            self.box_size), RealSort())
    self.initialize_objectives()
    # check() returns a CheckSatResult object, which is always truthy, so the
    # result must be compared with `sat` for the check to mean anything.
    if self.solver.check() != sat:
        raise RuntimeError("S-box structure constraints are not satisfiable")

    if prune_level >= 0:
        # Probe the requested level in its own scope. If it is unsat, its
        # constraints are retracted by pop(). Otherwise they would stay in the
        # solver and make every level tried by the search below fail too.
        self.solver.push()
        self.add_bias_constraints(prune_level)
        feasible = self.solver.check() == sat
        self.solver.pop()
        if feasible:
            self.add_bias_constraints(prune_level)
            self.prune_level = prune_level
            return
        print("Provided pruning level unsat, searching optimal pruning")

    print("searching best pruning level")
    low, high = 0, len(self.sbox) // 4
    while low <= high:
        mid = (low + high) // 2
        print("trying pruning", mid)
        self.solver.push()
        # bound each probe so that an expensive level counts as a failure
        self.solver.set(timeout=10000)
        self.add_bias_constraints(mid)
        if self.solver.check() == sat:
            print("success")
            low = mid + 1
        else:
            print("failure")
            high = mid - 1
        self.solver.pop()
    self.solver.set(timeout=0)
    print("best pruning", high)
    self.prune_level = high
    self.add_bias_constraints(high)
def initialize_objectives(self)

Initializes the objective functions for the cryptographic solver.

The method sets up four objective functions: 'original_linear', 'reduced', 'differential', and 'linear'. These objective functions are used to guide the solver in finding the optimal solution. Each objective function is a lambda function that builds the objective value for a given number of rounds. The 'reduced' objective is the one used for both the linear and the differential search; the other objective functions are only there for reference and for easy evaluation of the bias directly from the model.

Expand source code
def initialize_objectives(self):
    """Initializes the objective functions for the cryptographic solver.

    Four objectives are stored in ``self.objectives``. Each is a callable that
    takes a number of rounds and builds a solver expression from the
    ``self.sboxf`` values of every S-box (input, output) pair in those rounds:

    - 'original_linear': ``2**(num_blocks * rounds - 1)`` times the product
      of the values.
    - 'reduced': the plain sum of the values. This is the objective used for
      both the linear and the differential search.
    - 'differential': the product of the values divided by
      ``(2**box_size)**(num_blocks * rounds)``.
    - 'linear': the 'original_linear' expression divided by
      ``(2**box_size)**(num_blocks * rounds)``.

    The objectives other than 'reduced' are kept for reference and for easy
    evaluation of the bias directly from a model.
    """
    def sbox_terms(rounds):
        # one sboxf application per S-box, ordered round by round, block by block
        blocks = self.num_blocks
        return [
            self.sboxf(self.inps[k // blocks][k % blocks],
                       self.oups[k // blocks][k % blocks])
            for k in range(blocks * rounds)
        ]

    def bias_scale(rounds):
        return 2**(self.num_blocks * rounds - 1)

    def normaliser(rounds):
        return (2**self.box_size)**(self.num_blocks * rounds)

    self.objectives = {
        # the actual objective: product of the biases, each in [0, 1/2]
        'original_linear': lambda rounds: bias_scale(rounds) * Product(sbox_terms(rounds)),
        # the product-maximisation problem reduced to a sum
        'reduced': lambda rounds: sum(sbox_terms(rounds)),
        # values are counts in [0, 2**box_size]; only the final division remains
        'differential': lambda rounds: Product(sbox_terms(rounds)) / normaliser(rounds),
        'linear': lambda rounds: bias_scale(rounds) * Product(sbox_terms(rounds)) / normaliser(rounds),
    }
def initialize_sbox_structure(self)

Initializes the S-box structure for the cryptographic solver.

This method sets up the structure of the S-box by initializing input and output bit vectors for each round and adding the corresponding constraints to the solver. It also creates a concatenated view of the input and output layers for further processing.

Expand source code
def initialize_sbox_structure(self):
    """Initializes the S-box structure for the cryptographic solver.

    Creates one bit vector per S-box input (``self.inps``, ``num_rounds + 1``
    layers) and per S-box output (``self.oups``, ``num_rounds`` layers). It
    then adds these constraints to ``self.solver``:

    - each round's outputs are related to the next round's inputs through
      ``bitvec_permutation``;
    - the first-layer input is not entirely zero;
    - an S-box output is non-zero exactly when its input is non-zero.

    Finally it builds ``self.bv_inp_masks`` / ``self.bv_oup_masks``, a
    whole-layer view of every input/output layer. This is the single bit
    vector when there is one block, otherwise the concatenation of the blocks.
    """
    width = self.box_size
    blocks = range(self.num_blocks)
    single = self.num_blocks == 1

    def whole(layer):
        # a single bit vector covering the entire layer
        return layer[0] if single else Concat(layer)

    self.inps = [[BitVec('r{}_i{}'.format(r, i), width) for i in blocks]
                 for r in range(self.num_rounds + 1)]
    self.oups = [[BitVec('r{}_o{}'.format(r, i), width) for i in blocks]
                 for r in range(self.num_rounds)]

    # the permuted S-box outputs of a round are the inputs of the next round
    for rnd in range(self.num_rounds):
        self.solver.add(self.bitvec_permutation(
            whole(self.oups[rnd]), whole(self.inps[rnd + 1])))

    # at least one first-layer input block must be non-zero
    self.solver.add(Not(And(*[bv == 0 for bv in self.inps[0]])))

    # an S-box has an output exactly when it has an input
    for in_layer, out_layer in zip(self.inps, self.oups):
        for x, y in zip(in_layer, out_layer):
            self.solver.add(Implies(x != 0, y != 0))
            self.solver.add(Implies(x == 0, y == 0))

    # concatenated view of the input and output layers
    self.bv_inp_masks = [whole(layer) for layer in self.inps]
    self.bv_oup_masks = [whole(layer) for layer in self.oups]
def print_bitrelations(self, inp_masks, out_masks)

Print the input and output masks of a block cipher in a formatted manner.

Args

inp_masks : list of int
Input masks for each round.
out_masks : list of int
Output masks for each round.

Expand source code
def print_bitrelations(self, inp_masks, out_masks):
    """
    Print the input and output masks of a block cipher in a formatted manner.

    Each round prints four lines: its input mask, a dashed separator with one
    dash group per S-box, its output mask, and a blank line. The final entry
    of ``inp_masks`` is printed last. Masks are shown in binary, zero-padded
    to the full block width, with ``|`` between S-boxes.

    :param inp_masks: List of integers, input masks for each round.
    :param out_masks: List of integers, output masks for each round.
    """
    width = self.box_size
    total = self.num_blocks * width

    def bin_sep(val):
        bits = bin(val)[2:].zfill(total)
        return "|".join([bits[pos:pos + width] for pos in range(0, total, width)])

    for idx in range(len(out_masks)):
        in_mask, out_mask = inp_masks[idx], out_masks[idx]
        print(bin_sep(in_mask))
        print(' '.join('-' * width for _ in range(total // width)))
        print(bin_sep(out_mask))
        print()
    print(bin_sep(inp_masks[-1]))
def search_best_masks(self, tolerance=1, choose_best=10, display_paths=True)

Searches for the best masks with the highest total bias and limited undiscovered active blocks.

This method searches for the best masks with the highest total bias and a limited number of undiscovered active blocks.

Args

tolerance : int, optional
The maximum number of undiscovered active blocks allowed.
choose_best : int, optional
The number of best masks to choose from.
display_paths : bool, optional
Whether to display the characteristic paths (containing the bits involved) of the solutions.

Returns

list
A list of tuples, one per chosen solution. Each tuple contains the input mask of the first round, the input mask of the last round's S-box layer, and the total bias.
Expand source code
def search_best_masks(self, tolerance=1, choose_best=10, display_paths=True):
    """Searches for the best masks with the highest total bias and limited undiscovered active blocks.

    While the solver is satisfiable, up to ``choose_best`` solutions are
    requested. Every solution is recorded in ``self.solutions`` under its
    total bias. The search then keeps only the solutions that activate
    between 1 and ``tolerance`` not-yet-discovered blocks in the last round's
    S-box input. Among those it picks the one with the highest total bias,
    breaking ties in favour of fewer active blocks. After each pick, the
    blocks it activates are marked discovered. The solver is then constrained
    so that later solutions must activate at least one still-undiscovered
    block.

    Args:
        tolerance (int, optional): The maximum number of undiscovered active blocks allowed.
        choose_best (int, optional): The number of best masks to choose from.
        display_paths (bool, optional): Whether to display the characteristic paths
                                    (containing the bits involved) of the solutions.

    Returns:
        list: A list of ``(first_round_input_mask, last_round_input_mask, total_bias)``
              tuples, one per chosen solution.
    """
    self.init_characteristic_solver()
    last = self.num_rounds
    seen = [False] * self.num_blocks
    allowed = range(1, tolerance + 1)

    def within_tolerance(candidate):
        # number of blocks the candidate activates that are not discovered yet
        fresh = sum((not known) and act for known, act in zip(seen, candidate[3]))
        return fresh in allowed

    chosen = []
    while self.solver.check() == sat:
        batch = self.get_masks(self.num_rounds, choose_best, display_paths=False)
        for sol in batch:
            self.solutions[sol[2]].append(sol)
        candidates = [sol for sol in batch if within_tolerance(sol)]
        if not candidates:
            continue
        inp_masks, oup_masks, total_bias, active = max(
            candidates, key=lambda sol: (sol[2], -sum(sol[3])))
        if display_paths:
            self.print_bitrelations(inp_masks, oup_masks)
            print("total bias:", total_bias)
            print()
        chosen.append((inp_masks[0], inp_masks[last - 1], total_bias))
        for idx in range(len(seen)):
            if not seen[idx] and active[idx]:
                seen[idx] = True
        print("discovered", "".join(str(int(flag)) for flag in seen))
        # Require that at least one still-undiscovered block is active, so no
        # solution whose active blocks are all already discovered is found again.
        self.solver.add(Not(And(
            [self.inps[last - 1][idx] == 0 for idx, flag in enumerate(seen) if not flag]
        )))
    return chosen
def search_exclusive_masks(self, prune_level=-1, repeat=1)

Searches for the masks for each block by including only one block and excluding all the others.

This method searches for the masks for each block by including only one block and excluding all the others.

Args

prune_level : int, optional
The level at which to prune the biases.
repeat : int, optional
The number of times to repeat the search for each block.

Returns

list
A list of tuples. Each tuple contains the input masks, the output masks, and the total bias for a solution.
Expand source code
def search_exclusive_masks(self, prune_level=-1, repeat=1):
    """Searches for the masks for each block by including only one block and excluding all the others.

    The solver is initialized with ``prune_level``. Then, for each block in
    turn, ``solve_for_blocks`` is asked for ``repeat`` solutions in which
    that block is forced active and every other block is forced inactive.

    Args:
        prune_level (int, optional): The level at which to prune the biases.
        repeat (int, optional): The number of times to repeat the search for each block.

    Returns:
        list: The concatenation, in block order, of the lists returned by
              ``solve_for_blocks``; each entry is a tuple of the first-round
              input mask, the last-round input mask and the bias.
    """
    self.init_characteristic_solver(prune_level)
    every_block = set(range(self.num_blocks))
    found = []
    for block in range(self.num_blocks):
        only = {block}
        found += self.solve_for_blocks(only, every_block - only, num_sols=repeat)
    return found
def solve_for_blocks(self, include_blocks=(), exclude_blocks=(), num_rounds=0, num_sols=1, display_paths=True)

Solves the characteristic for the specified blocks and maximizes the objective function.

This method searches the characteristic for the specified blocks, maximizes the objective function, and returns the solutions. The blocks to include and exclude in the characteristic can be specified. The number of rounds and the number of solutions can also be specified.

Args

include_blocks : list, optional
The blocks to definitely include in the characteristic.
exclude_blocks : list, optional
The blocks to definitely exclude in the characteristic.
num_rounds : int, optional
The number of rounds for which to solve the characteristic. If not provided or 0, the number of rounds will be set to the number of rounds of the solver.
num_sols : int, optional
The number of solutions to return.
display_paths : bool, optional
Whether to display the paths of the solutions.

Returns

list
A list of tuples, one per solution. Each tuple contains the input mask of the first round, the input mask of the last round's S-box layer, and the calculated bias.
Expand source code
def solve_for_blocks(self, include_blocks=(), exclude_blocks=(),
                     num_rounds=0,
                     num_sols=1,
                     display_paths=True):
    """Solves the characteristic for the specified blocks and maximizes the objective function.

    Any scopes still holding an objective are popped first. A fresh scope is
    then pushed, and the blocks of the last round's S-box input are
    constrained:

    - blocks in ``include_blocks`` must be active (non-zero);
    - blocks in ``exclude_blocks`` must be inactive (zero);
    - any other block is left for the solver to decide.

    The 'reduced' objective is maximized, and up to ``num_sols`` solutions
    are fetched with ``get_masks``. They are recorded in ``self.solutions``
    under the sorted (include, exclude) key.

    Args:
        include_blocks (list, optional): The blocks to definitely include in the characteristic.
        exclude_blocks (list, optional): The blocks to definitely exclude in the characteristic.
        num_rounds (int, optional): The number of rounds for which to solve the characteristic.
                                     If not provided or 0, the number of rounds will be set to the
                                     number of rounds of the solver; larger values are capped to it.
        num_sols (int, optional): The number of solutions to return.
        display_paths (bool, optional): Whether to display the paths of the solutions.

    Returns:
        list: A list of ``(first_round_input_mask, last_round_input_mask, bias)``
              tuples, one per solution.
    """
    # 0 selects every initialised round; larger requests are capped to them
    rounds = self.num_rounds if num_rounds == 0 else min(self.num_rounds, num_rounds)
    # Pop scopes until no objective remains. This discards the objective and
    # include/exclude constraints pushed by a previous call.
    while len(self.solver.objectives()) > 0:
        self.solver.pop()
    self.solver.push()  # checkpoint for this call's constraints
    layer = rounds - 1
    for blk in include_blocks:
        self.solver.add(self.inps[layer][blk] != 0)
    for blk in exclude_blocks:
        self.solver.add(self.inps[layer][blk] == 0)
    self.solver.maximize(self.objectives['reduced'](rounds))
    found = self.get_masks(rounds, num_sols, display_paths)
    key = (tuple(sorted(include_blocks)), tuple(sorted(exclude_blocks)))
    self.solutions[key].extend(found)
    return [(masks[0], masks[-1], bias) for masks, _, bias, _ in found]
class DifferentialCryptanalysis (sbox, pbox, num_rounds)

Class for performing differential cryptanalysis.

Methods: - `__init__`: Initialize the differential cryptanalysis algorithm. - find_keybits: Find the key bits using differential cryptanalysis. - find_last_roundkey: Find the last round key using differential cryptanalysis. - generate_encryption_pairs: Generate encryption pairs for differential cryptanalysis.

Initialize the DifferentialCryptanalysis algorithm.

Parameters: - sbox: The substitution box used in the SPN - pbox: The permutation box used in the SPN - num_rounds: The number of rounds in the SPN

Notes: - This method is called when creating an instance of the DifferentialCryptanalysis class.

Expand source code
class DifferentialCryptanalysis(Cryptanalysis):
    """
    Class for performing differential cryptanalysis.

    Methods:
    - __init__: Initialize the differential cryptanalysis algorithm.
    - find_keybits: Find the key bits using differential cryptanalysis.
    - find_last_roundkey: Find the last round key using differential cryptanalysis.
    - generate_encryption_pairs: Generate encryption pairs for differential cryptanalysis.
    """

    def __init__(self, sbox, pbox, num_rounds):
        """
        Initialize the DifferentialCryptanalysis algorithm.

        Parameters:
        - sbox: The substitution box used in the SPN
        - pbox: The permutation box used in the SPN
        - num_rounds: The number of rounds in the SPN

        Notes:
        - This method is called when creating an instance of the DifferentialCryptanalysis class.
        """
        super().__init__(sbox, pbox, num_rounds, 'differential')

    def find_keybits(self, out_mask, ct_pairs, known_keyblocks=()):
        """Finds the key bits based on the output mask and ciphertext pairs.

        This method overrides the abstract `find_keybits` method in the `Cryptanalysis` class.

        How candidates are scored:

        - Every value is tried for the key blocks that are active in
          ``out_mask`` and not listed in ``known_keyblocks``. All other key
          blocks are set to 0.
        - For each candidate, every ciphertext pair is partially decrypted
          with ``dec_partial_last_noperm``.
        - The candidate's score is the number of pairs whose difference
          matches ``out_mask`` on the guessed blocks.
        - The top ``box_size`` candidates are printed.

        Args:
            out_mask (int): The output mask for the difference in encrypted pairs.
            ct_pairs (list): A list of ciphertext pairs used for analysis.
            known_keyblocks (list, optional): Indices of key blocks that are already known
                and therefore not guessed. Defaults to an empty tuple.

        Returns:
            tuple: ``(key, count)`` for the highest-scoring candidate. ``key`` is the
            candidate key as an integer; ``count`` is the number of matching pairs.

        Raises:
            ValueError: If ``ct_pairs`` is empty, since no candidate can then be scored.
        """
        if not ct_pairs:
            raise ValueError("ct_pairs must contain at least one ciphertext pair")
        known = set(known_keyblocks)
        out_blocks = self.int_to_list(out_mask)
        active_blocks = [i for i, v in enumerate(out_blocks) if v and i not in known]
        key_diffcounts = Counter()
        for klst in product(range(len(self.sbox)), repeat=len(active_blocks)):
            key = [0] * self.num_sbox
            for i, v in zip(active_blocks, klst):
                key[i] = v
            key = self.list_to_int(key)
            for ct1, ct2 in ct_pairs:
                diff = self.dec_partial_last_noperm(
                    ct1, [key]) ^ self.dec_partial_last_noperm(
                    ct2, [key])
                diff = self.int_to_list(diff)
                # only the guessed (active, unknown) blocks are compared
                key_diffcounts[key] += all(out_blocks[i] == diff[i] for i in active_blocks)
        topn = key_diffcounts.most_common(self.box_size)
        for i, v in topn:
            print(self.int_to_list(i), v)
        return topn[0]

    def find_last_roundkey(self, outmasks, cutoff=10000, multiple=1000):
        """Finds the last round key based on output masks.

        This method overrides the abstract `find_last_roundkey` method in the `Cryptanalysis` class.

        Procedure:

        - Plaintext/ciphertext pairs are generated for every characteristic.
        - Each characteristic's output mask is used in turn to guess the key
          blocks it activates, with the blocks found so far passed as known.
        - A block's value is fixed by the first characteristic that activates it.

        Args:
            outmasks (list): A list of tuples of (input, output masks, bias) for which
                                the last round key needs to be found.
            cutoff (int, optional): The cutoff value used for the maximum number of encryptions
                                    called from oracle in determining the last round key. Defaults to 10000.
            multiple (int, optional): The multiple indicating the size of the batch of values to be
                                encrypted at once used for generating encryption pairs. Defaults to 1000.

        Returns:
            list: The last round key as one entry per S-box. An entry is ``None``
                  if no output mask activated that block.
        """
        final_key = [None] * self.num_sbox
        all_pt_ct_pairs = self.generate_encryption_pairs(outmasks, cutoff, multiple=multiple)
        for pt_ct_pairs, (_, out_mask, _) in zip(all_pt_ct_pairs, outmasks):
            ct_pairs = [i[1] for i in pt_ct_pairs]
            k = self.find_keybits(out_mask, ct_pairs, [
                i for i, v in enumerate(final_key) if v is not None])
            kr = self.int_to_list(k[0])
            print(kr)
            for i, v in enumerate(self.int_to_list(out_mask)):
                if v and final_key[i] is None:
                    final_key[i] = kr[i]
            print(final_key)
            print()
        return final_key

    def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000):
        """Generates encryption pairs for a set of output masks.

        This method overrides the abstract `generate_encryption_pairs` method in the `Cryptanalysis` class.

        For each characteristic, plaintext pairs with the characteristic's
        input difference are collected in three stages:

        1. Pairs already present in ``self.encryptions``.
        2. Pairs completed from single known plaintexts.
        3. Fresh random pairs.

        Collection stops at ``min(100 * int(1 / bias), cutoff)`` pairs. That
        number is capped by the number of distinct pairs the block size
        allows. All pending plaintexts are then encrypted in one batch.

        Args:
            outmasks (list): A list of tuples of (input_diff_mask, output_diff_mask and bias)
                            for which encryption pairs need to be generated.
            cutoff (int, optional): The cutoff value of the max number of encryptions invoked. Defaults to 10000.
            multiple (int, optional): The multiple indicating the size of the batch of values to be
                                encrypted at once used for generating encryption pairs. Defaults to 1000.

        Returns:
            list: A list of plaintext-ciphertext pairs for each output mask.
        """
        # A fixed non-zero difference splits the 2**bits plaintexts into
        # 2**(bits - 1) disjoint pairs. Requesting more pairs than that would
        # make the random sampling loop below spin forever on small blocks.
        block_bits = self.num_sbox * self.box_size
        max_pairs = 1 << max(block_bits - 1, 0)
        all_pt_pairs = []
        for inp_mask, _, bias in outmasks:
            pt_pairs = []
            new_encs = {}  # new encryptions + seen encryptions
            threshold = min(100 * int(1 / bias), cutoff, max_pairs)
            # first pass, look for already existing pairs
            for i in self.encryptions:
                if len(pt_pairs) >= threshold:
                    break
                if i in new_encs:
                    # already added the pair i.e i^inp_mask
                    continue
                if i ^ inp_mask in self.encryptions:
                    new_encs[i] = self.encryptions[i]
                    new_encs[i ^ inp_mask] = self.encryptions[i ^ inp_mask]
                    pt_pairs.append((i, i ^ inp_mask))
            for i in set(self.encryptions) - set(new_encs):
                if len(pt_pairs) >= threshold:
                    break
                # only add if we have exhausted our already encrypted pairs
                new_encs[i] = self.encryptions[i]
                new_encs[i ^ inp_mask] = None  # marked to be encrypted
                pt_pairs.append((i, i ^ inp_mask))
            self.encryptions.update(new_encs)
            while len(pt_pairs) < threshold:
                r = random.randint(0, 2**block_bits - 1)
                if r in self.encryptions or r ^ inp_mask in self.encryptions:
                    continue
                self.encryptions[r] = None
                self.encryptions[r ^ inp_mask] = None
                pt_pairs.append((r, r ^ inp_mask))
            all_pt_pairs.append(pt_pairs)

        # encrypt every plaintext still marked None
        self.update_encryptions(multiple=multiple)

        all_pt_ct_pairs = []
        for pt_pairs in all_pt_pairs:
            pt_ct_pairs = []
            for (p1, p2) in pt_pairs:
                pt_ct_pairs.append(
                    ((p1, p2), (self.encryptions[p1], self.encryptions[p2])))
            all_pt_ct_pairs.append(pt_ct_pairs)
        return all_pt_ct_pairs

Ancestors

Methods

def find_keybits(self, out_mask, ct_pairs, known_keyblocks=())

Finds the key bits based on the output mask and ciphertext pairs.

This method overrides the abstract find_keybits method in the Cryptanalysis class. It takes an output mask, a list of ciphertext pairs, and an optional list of known key blocks as input. It implements the logic to find the key bits based on the provided parameters.

Args

out_mask : int
The output mask for the difference in encrypted pairs.
ct_pairs : list
A list of ciphertext pairs used for analysis.
known_keyblocks : list, optional
A list of known key blocks. Defaults to an empty list.

Returns

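tuple
A (key, count) pair: the most likely key bits as an integer, and the number of ciphertext pairs that showed the expected difference under that key.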
Expand source code
def find_keybits(self, out_mask, ct_pairs, known_keyblocks=()):
    """Finds the key bits based on the output mask and ciphertext pairs.

    This method overrides the abstract `find_keybits` method in the `Cryptanalysis` class.

    How candidates are scored:

    - Every value is tried for the key blocks that are active in
      ``out_mask`` and not listed in ``known_keyblocks``. All other key
      blocks are set to 0.
    - For each candidate, every ciphertext pair is partially decrypted with
      ``dec_partial_last_noperm``.
    - The candidate's score is the number of pairs whose difference matches
      ``out_mask`` on the guessed blocks.
    - The top ``box_size`` candidates are printed.

    Args:
        out_mask (int): The output mask for the difference in encrypted pairs.
        ct_pairs (list): A list of ciphertext pairs used for analysis.
        known_keyblocks (list, optional): Indices of key blocks that are already known
            and therefore not guessed. Defaults to an empty tuple.

    Returns:
        tuple: ``(key, count)`` for the highest-scoring candidate. ``key`` is the
        candidate key as an integer; ``count`` is the number of matching pairs.

    Raises:
        ValueError: If ``ct_pairs`` is empty, since no candidate can then be scored.
    """
    if not ct_pairs:
        raise ValueError("ct_pairs must contain at least one ciphertext pair")
    known = set(known_keyblocks)
    out_blocks = self.int_to_list(out_mask)
    active_blocks = [i for i, v in enumerate(out_blocks) if v and i not in known]
    key_diffcounts = Counter()
    for klst in product(range(len(self.sbox)), repeat=len(active_blocks)):
        key = [0] * self.num_sbox
        for i, v in zip(active_blocks, klst):
            key[i] = v
        key = self.list_to_int(key)
        for ct1, ct2 in ct_pairs:
            diff = self.dec_partial_last_noperm(
                ct1, [key]) ^ self.dec_partial_last_noperm(
                ct2, [key])
            diff = self.int_to_list(diff)
            # only the guessed (active, unknown) blocks are compared
            key_diffcounts[key] += all(out_blocks[i] == diff[i] for i in active_blocks)
    topn = key_diffcounts.most_common(self.box_size)
    for i, v in topn:
        print(self.int_to_list(i), v)
    return topn[0]
def find_last_roundkey(self, outmasks, cutoff=10000, multiple=1000)

Finds the last round key based on output masks.

This method overrides the abstract find_last_roundkey method in the Cryptanalysis class. It takes a list of output masks, a cutoff value, and a multiple value as input. It implements the logic to find the last round key based on the output masks and the specified parameters.

Args

outmasks : list
A list of tuples of (input, output masks, bias) for which the last round key needs to be found.
cutoff : int, optional
The cutoff value used for the maximum number of encryptions called from oracle in determining the last round key. Defaults to 10000.
multiple : int, optional
The multiple indicating the size of the batch of values to be encrypted at once used for generating encryption pairs. Defaults to 1000.

Returns

list
The last round key determined based on the output masks.
Expand source code
def find_last_roundkey(self, outmasks, cutoff=10000, multiple=1000):
    """Finds the last round key based on output masks.

    This method overrides the abstract `find_last_roundkey` method in the `Cryptanalysis` class.

    Procedure:

    - Plaintext/ciphertext pairs are generated for every characteristic.
    - For each output mask in turn, ``find_keybits`` guesses the key blocks
      that mask activates, with the blocks determined so far passed as known.
    - A block's value is taken from the first characteristic that activates it.

    Args:
        outmasks (list): A list of tuples of (input, output masks, bias) for which
                            the last round key needs to be found.
        cutoff (int, optional): The cutoff value used for the maximum number of encryptions
                                called from oracle in determining the last round key. Defaults to 10000.
        multiple (int, optional): The multiple indicating the size of the batch of values to be
                            encrypted at once used for generating encryption pairs. Defaults to 1000.

    Returns:
        list: The last round key as one entry per S-box. An entry is ``None``
              if no output mask activated that block.
    """
    recovered = [None] * self.num_sbox
    batches = self.generate_encryption_pairs(outmasks, cutoff, multiple=multiple)
    for batch, (_, out_mask, _) in zip(batches, outmasks):
        ciphertext_pairs = [entry[1] for entry in batch]
        known = [idx for idx, blk in enumerate(recovered) if blk is not None]
        best = self.find_keybits(out_mask, ciphertext_pairs, known)
        guessed = self.int_to_list(best[0])
        print(guessed)
        for idx, is_active in enumerate(self.int_to_list(out_mask)):
            # keep the first value found for each block
            if is_active and recovered[idx] is None:
                recovered[idx] = guessed[idx]
        print(recovered)
        print()
    return recovered
def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000)

Generates encryption pairs for a set of output masks.

This method overrides the abstract generate_encryption_pairs method in the Cryptanalysis class. It takes a list of output masks, a cutoff value, and a multiple value as input. It generates plaintext-ciphertext pairs for each output mask based on the specified parameters.

Args

outmasks : list
A list of tuples of (input_diff_mask, output_diff_mask and bias) for which encryption pairs need to be generated.
cutoff : int, optional
The cutoff value of the max number of encryptions invoked. Defaults to 10000.
multiple : int, optional
The multiple indicating the size of the batch of values to be encrypted at once used for generating encryption pairs. Defaults to 1000.

Returns

list
A list of plaintext-ciphertext pairs for each output mask.
Expand source code
def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000):
    """Generates encryption pairs for a set of output masks.

    This method overrides the abstract `generate_encryption_pairs` method in the `Cryptanalysis` class.
    For every ``(input_diff_mask, output_diff_mask, bias)`` tuple it gathers plaintext pairs
    ``(p, p ^ input_diff_mask)`` in three passes:

    1. pairs whose two plaintexts are both already keys of ``self.encryptions``;
    2. already-known plaintexts paired with a partner that is not known yet
       (the partner is registered with a ``None`` placeholder);
    3. fresh random pairs where neither plaintext is known yet.

    Afterwards ``self.update_encryptions`` is called once to fill in the ``None``
    placeholders, and the ciphertexts are read back from ``self.encryptions``.

    Args:
        outmasks (list): A list of tuples of (input_diff_mask, output_diff_mask, bias)
                        for which encryption pairs need to be generated. ``bias`` must be non-zero.
        cutoff (int, optional): Upper bound on the number of pairs collected per mask.
                                Defaults to 10000.
        multiple (int, optional): The multiple indicating the size of the batch of values to be
                            encrypted at once used for generating encryption pairs. Defaults to 1000.

    Returns:
        list: For each output mask, a list of ``((p1, p2), (c1, c2))`` tuples.

    Raises:
        ZeroDivisionError: If a ``bias`` is 0.
    """
    block_bits = self.num_sbox * self.box_size
    # A fixed input difference splits the 2**block_bits plaintexts into at most
    # 2**(block_bits - 1) disjoint pairs. Without this cap, a threshold larger
    # than that (easy to hit on small toy blocks) would make the random-sampling
    # loop below spin forever because no unused pair is left to draw.
    max_pairs = 1 << max(block_bits - 1, 0)
    all_pt_pairs = []
    for inp_mask, _, bias in outmasks:
        pt_pairs = []
        new_encs = {}  # plaintexts used for this mask -> ciphertext (None = still to encrypt)
        threshold = min(100 * int(1 / bias), cutoff, max_pairs)
        # first pass: reuse pairs whose both halves were already requested
        for i in self.encryptions:
            if len(pt_pairs) >= threshold:
                break
            if i in new_encs:
                # already added the pair i.e i^inp_mask
                continue
            if i ^ inp_mask in self.encryptions:
                new_encs[i] = self.encryptions[i]
                new_encs[i ^ inp_mask] = self.encryptions[i ^ inp_mask]
                pt_pairs.append((i, i ^ inp_mask))
        # second pass: pair each remaining known plaintext with a new partner
        for i in set(self.encryptions) - set(new_encs):
            if len(pt_pairs) >= threshold:
                break
            new_encs[i] = self.encryptions[i]
            new_encs[i ^ inp_mask] = None  # marked to be encrypted
            pt_pairs.append((i, i ^ inp_mask))
        self.encryptions.update(new_encs)
        # third pass: completely fresh pairs; terminates thanks to the max_pairs cap
        while len(pt_pairs) < threshold:
            r = random.randint(0, 2**block_bits - 1)
            if r in self.encryptions or r ^ inp_mask in self.encryptions:
                continue
            self.encryptions[r] = None
            self.encryptions[r ^ inp_mask] = None
            pt_pairs.append((r, r ^ inp_mask))
        all_pt_pairs.append(pt_pairs)

    # encrypt every plaintext still holding a None placeholder, in batches
    self.update_encryptions(multiple=multiple)

    all_pt_ct_pairs = []
    for pt_pairs in all_pt_pairs:
        all_pt_ct_pairs.append([
            ((p1, p2), (self.encryptions[p1], self.encryptions[p2]))
            for p1, p2 in pt_pairs
        ])
    return all_pt_ct_pairs

Inherited members

class LinearCryptanalysis (sbox, pbox, num_rounds)

Class for performing linear cryptanalysis.

Methods: - init: Initialize the linear cryptanalysis algorithm. - find_keybits: Find the key bits using linear cryptanalysis. - _find_keybits_multimasks: Experimental method utilising multiple linear characteristics on the same block to find key bits - find_last_roundkey: Find the last round key using linear cryptanalysis. - generate_encryption_pairs: Generate encryption pairs for linear cryptanalysis.

Initialize the LinearCryptanalysis algorithm.

Parameters: - sbox: The substitution box used in the SPN - pbox: The permutation box used in the SPN - num_rounds: The number of rounds in the SPN

Notes: - This method is called when creating an instance of the LinearCryptanalysis class.

Expand source code
class LinearCryptanalysis(Cryptanalysis):
    """
    Class for performing linear cryptanalysis.

    Methods:
    - __init__: Initialize the linear cryptanalysis algorithm.
    - find_keybits: Find the key bits using linear cryptanalysis.
    - _find_keybits_multimasks: Experimental method utilising multiple linear characteristics on the same block to find key bits
    - find_last_roundkey: Find the last round key using linear cryptanalysis.
    - generate_encryption_pairs: Generate encryption pairs for linear cryptanalysis.
    """

    def __init__(self, sbox, pbox, num_rounds):
        """
        Initialize the LinearCryptanalysis algorithm.

        Parameters:
        - sbox: The substitution box used in the SPN
        - pbox: The permutation box used in the SPN
        - num_rounds: The number of rounds in the SPN

        Notes:
        - Delegates to ``Cryptanalysis.__init__`` with the mode string ``'linear'``.
        """
        super().__init__(sbox, pbox, num_rounds, 'linear')

    def _find_keybits_multimasks(self, in_out_masks, ptct_pairs, known_keyblocks=()):
        """Finds the key bits based on multiple input-output masks for a given block and plaintext-ciphertext pairs.

        Note that this method is experimental to try out using more than one linear
        characteristic ending on a block. For every mask, each candidate subkey on the
        active (not yet known) blocks is scored by ``|count - len(ptct_pairs) // 2|``,
        and that score is added to the per-block Counter of every block it covers.

        Args:
            in_out_masks (list): A list of ``(in_mask, out_mask, bias)`` tuples; ``bias`` is unused here.
            ptct_pairs (list): A list of ``(plaintext, ciphertext)`` pairs used for analysis.
            known_keyblocks (iterable, optional): Indices of key blocks that are already known.
                Defaults to an empty tuple.

        Returns:
            list: One Counter per S-box, mapping candidate subkey values for that block
            to their accumulated score.
        """
        known = set(known_keyblocks)
        half = len(ptct_pairs) // 2
        key_diffcounts = [Counter() for _ in range(self.num_sbox)]
        for in_mask, out_mask, _ in tqdm(in_out_masks):
            out_blocks = self.int_to_list(out_mask)
            active_blocks = [i for i, v in enumerate(out_blocks) if v and i not in known]
            key_diffcount_curr = Counter()
            for klst in product(range(len(self.sbox)), repeat=len(active_blocks)):
                # candidate key: guessed values on active blocks, 0 everywhere else
                key = [0] * self.num_sbox
                for i, v in zip(active_blocks, klst):
                    key[i] = v
                key = self.list_to_int(key)
                for pt, ct in ptct_pairs:
                    ct_last = self.dec_partial_last_noperm(ct, [key])
                    key_diffcount_curr[key] += parity((pt & in_mask) ^ (ct_last & out_mask))
            for i in key_diffcount_curr:
                count = abs(key_diffcount_curr[i] - half)
                key_list = self.int_to_list(i)
                for j in active_blocks:
                    key_diffcounts[j][key_list[j]] += count
            # debug output: current best candidates for each block touched by this mask
            for j in active_blocks:
                topn = key_diffcounts[j].most_common(self.box_size)
                for i, v in topn:
                    print(i, v)
        return key_diffcounts

    def find_keybits(self, in_mask, out_mask, ptct_pairs, known_keyblocks=()):
        """Finds the key bits based on an input mask, an output mask, and plaintext-ciphertext pairs.

        Every candidate value of the last-round subkey on the blocks that are active in
        ``out_mask`` (and not listed in ``known_keyblocks``) is tried. The candidate is used
        to partially decrypt each ciphertext through the last round, and the number of pairs
        for which ``parity((pt & in_mask) ^ (ct_last & out_mask))`` is 1 is counted. A
        candidate's score is the distance of that count from ``len(ptct_pairs) // 2``.

        Args:
            in_mask (int): The input mask for the key search.
            out_mask (int): The output mask for the key search.
            ptct_pairs (list): A list of tuples of plaintext-ciphertext pairs used for analysis.
            known_keyblocks (iterable, optional): Indices of key blocks that are already
                recovered; these blocks are left at 0 in every candidate. Defaults to an empty tuple.

        Returns:
            tuple: ``(key, score)`` for the highest-scoring candidate. ``key`` is an integer
            holding the candidate subkey bits on the active blocks (other blocks zero), and
            ``score`` is its bias count.

        Raises:
            IndexError: If ``ptct_pairs`` is empty (no candidate gets scored).
        """
        known = set(known_keyblocks)
        out_blocks = self.int_to_list(out_mask)
        active_blocks = [i for i, v in enumerate(out_blocks) if v and i not in known]
        half = len(ptct_pairs) // 2
        key_diffcounts = Counter()
        for klst in product(range(len(self.sbox)), repeat=len(active_blocks)):
            key = [0] * self.num_sbox
            for i, v in zip(active_blocks, klst):
                key[i] = v
            key = self.list_to_int(key)
            for pt, ct in ptct_pairs:
                ct_last = self.dec_partial_last_noperm(ct, [key])
                key_diffcounts[key] += parity((pt & in_mask) ^ (ct_last & out_mask))
        for i in key_diffcounts:
            key_diffcounts[i] = abs(key_diffcounts[i] - half)
        topn = key_diffcounts.most_common(self.box_size)
        for i, v in topn:
            print(self.int_to_list(i), v)
        return topn[0]

    def find_last_roundkey(self, outmasks, cutoff=50000, multiple=1000):
        """Finds the last round key based on output masks.

        This method overrides the abstract `find_last_roundkey` method in the `Cryptanalysis` class.
        The masks are processed in order. Each block's key value is taken from the first
        characteristic whose output mask covers it; later masks skip already-recovered blocks.

        Args:
            outmasks (list): A list of tuples of (input, output masks, bias) for which
                                the last round key needs to be found.
            cutoff (int, optional): The cutoff value used for the maximum number of encryptions
                                    called from oracle in determining the last round key. Defaults to 50000.
            multiple (int, optional): The multiple indicating the size of the batch of values to be
                                encrypted at once used for generating encryption pairs. Defaults to 1000.

        Returns:
            list: One entry per S-box: the recovered key value, or ``None`` for blocks
            that no output mask covers.
        """
        final_key = [None] * self.num_sbox
        all_pt_ct_pairs = self.generate_encryption_pairs(outmasks, cutoff, multiple=multiple)
        for ptct_pairs, (inp_mask, out_mask, _) in zip(all_pt_ct_pairs, outmasks):
            known_blocks = [i for i, v in enumerate(final_key) if v is not None]
            k = self.find_keybits(inp_mask, out_mask, ptct_pairs, known_blocks)
            kr = self.int_to_list(k[0])
            print(kr)
            for i, v in enumerate(self.int_to_list(out_mask)):
                if v and final_key[i] is None:
                    final_key[i] = kr[i]
            print(final_key)
            print()
        return final_key

    def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000):
        """Generates encryption pairs for a set of output masks.

        Linear cryptanalysis needs plain known-plaintext pairs, so one shared pool of
        plaintexts is built (already-requested plaintexts first, then fresh random ones),
        encrypted via ``self.update_encryptions``, and returned once per mask.

        Args:
            outmasks (list): A list of tuples of (input, output masks, bias) for which
                                the last round key needs to be found. ``bias`` must be non-zero.
            cutoff (int, optional): Upper bound on the number of plaintexts used. Defaults to 10000.
            multiple (int, optional): The multiple indicating the size of the batch of values to be
                                encrypted at once used for generating encryption pairs. Defaults to 1000.

        Returns:
            list: ``len(outmasks)`` references to the same list of ``(plaintext, ciphertext)``
            tuples (an empty list when ``outmasks`` is empty).
        """
        from itertools import islice

        if not outmasks:
            return []
        block_bits = self.num_sbox * self.box_size
        max_threshold = max(100 * int(1 / (bias * bias)) for _, _, bias in outmasks)
        # Never ask for more distinct plaintexts than the block can hold; otherwise the
        # sampling loop below can never finish on small blocks.
        threshold = min(cutoff, max_threshold, 1 << block_bits)
        all_pt = list(islice(self.encryptions, threshold))
        while len(all_pt) < threshold:
            r = random.randint(0, 2**block_bits - 1)
            if r in self.encryptions:
                continue
            self.encryptions[r] = None  # placeholder, filled by update_encryptions
            all_pt.append(r)
        self.update_encryptions(multiple=multiple)
        all_ptct = [(i, self.encryptions[i]) for i in all_pt]
        return [all_ptct] * len(outmasks)

Ancestors

Methods

def find_keybits(self, in_mask, out_mask, ptct_pairs, known_keyblocks=())

Finds the key bits based on an input mask, an output mask, and plaintext-ciphertext pairs.

This method takes an input mask, an output mask, a list of plaintext-ciphertext pairs, and an optional list of known key blocks as input. It implements the logic to find the key bits based on the provided parameters.

Args

in_mask : int
The input mask for the key search.
out_mask : int
The output mask for the key search.
ptct_pairs : list
A list of tuples of plaintext-ciphertext pairs used for analysis.
known_keyblocks : list, optional
A list of known key blocks. Defaults to an empty list.

Returns

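tuple
A (key, score) tuple for the most probable candidate: key is an integer holding the candidate key bits on the active blocks, and score is its bias count.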
Expand source code
def find_keybits(self, in_mask, out_mask, ptct_pairs, known_keyblocks=()):
    """Finds the key bits based on an input mask, an output mask, and plaintext-ciphertext pairs.

    Every candidate value of the last-round subkey on the blocks that are active in
    ``out_mask`` (and not listed in ``known_keyblocks``) is tried. The candidate is used
    to partially decrypt each ciphertext through the last round, and the number of pairs
    for which ``parity((pt & in_mask) ^ (ct_last & out_mask))`` is 1 is counted. A
    candidate's score is the distance of that count from ``len(ptct_pairs) // 2``.

    Args:
        in_mask (int): The input mask for the key search.
        out_mask (int): The output mask for the key search.
        ptct_pairs (list): A list of tuples of plaintext-ciphertext pairs used for analysis.
        known_keyblocks (iterable, optional): Indices of key blocks that are already
            recovered; these blocks are left at 0 in every candidate. Defaults to an empty tuple.

    Returns:
        tuple: ``(key, score)`` for the highest-scoring candidate. ``key`` is an integer
        holding the candidate subkey bits on the active blocks (other blocks zero), and
        ``score`` is its bias count.

    Raises:
        IndexError: If ``ptct_pairs`` is empty (no candidate gets scored).
    """
    known = set(known_keyblocks)
    out_blocks = self.int_to_list(out_mask)
    active_blocks = [i for i, v in enumerate(out_blocks) if v and i not in known]
    half = len(ptct_pairs) // 2
    key_diffcounts = Counter()
    for klst in product(range(len(self.sbox)), repeat=len(active_blocks)):
        # candidate key: guessed values on active blocks, 0 everywhere else
        key = [0] * self.num_sbox
        for i, v in zip(active_blocks, klst):
            key[i] = v
        key = self.list_to_int(key)
        for pt, ct in ptct_pairs:
            ct_last = self.dec_partial_last_noperm(ct, [key])
            key_diffcounts[key] += parity((pt & in_mask) ^ (ct_last & out_mask))
    # convert raw parity counts into deviation from the unbiased expectation
    for i in key_diffcounts:
        key_diffcounts[i] = abs(key_diffcounts[i] - half)
    topn = key_diffcounts.most_common(self.box_size)
    for i, v in topn:
        print(self.int_to_list(i), v)
    return topn[0]
def find_last_roundkey(self, outmasks, cutoff=50000, multiple=1000)

Finds the last round key based on output masks.

This method overrides the abstract find_last_roundkey method in the Cryptanalysis class. It takes a list of output masks, a cutoff value, and a multiple value as input. It implements the logic to find the last round key based on the output masks and the specified parameters.

Args

outmasks : list
A list of tuples of (input, output masks, bias) for which the last round key needs to be found.
cutoff : int, optional
The cutoff value used for the maximum number of encryptions called from oracle in determining the last round key. Defaults to 50000.
multiple : int, optional
The multiple indicating the size of the batch of values to be encrypted at once used for generating encryption pairs. Defaults to 1000.

Returns

list
The last round key determined based on the output masks.
Expand source code
def find_last_roundkey(self, outmasks, cutoff=50000, multiple=1000):
    """Finds the last round key based on output masks.

    This method overrides the abstract `find_last_roundkey` method in the `Cryptanalysis` class.
    The masks are processed in order. Each block's key value is taken from the first
    characteristic whose output mask covers it; later masks skip already-recovered blocks.

    Args:
        outmasks (list): A list of tuples of (input, output masks, bias) for which
                            the last round key needs to be found.
        cutoff (int, optional): The cutoff value used for the maximum number of encryptions
                                called from oracle in determining the last round key. Defaults to 50000.
        multiple (int, optional): The multiple indicating the size of the batch of values to be
                            encrypted at once used for generating encryption pairs. Defaults to 1000.

    Returns:
        list: One entry per S-box: the recovered key value, or ``None`` for blocks
        that no output mask covers.
    """
    final_key = [None] * self.num_sbox
    all_pt_ct_pairs = self.generate_encryption_pairs(outmasks, cutoff, multiple=multiple)
    for ptct_pairs, (inp_mask, out_mask, _) in zip(all_pt_ct_pairs, outmasks):
        # blocks recovered by an earlier characteristic are excluded from the search
        known_blocks = [i for i, v in enumerate(final_key) if v is not None]
        k = self.find_keybits(inp_mask, out_mask, ptct_pairs, known_blocks)
        kr = self.int_to_list(k[0])
        print(kr)
        for i, v in enumerate(self.int_to_list(out_mask)):
            if v and final_key[i] is None:
                final_key[i] = kr[i]
        print(final_key)
        print()
    return final_key
def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000)

Generates encryption pairs for a set of output masks.

This method takes a list of output masks, a cutoff value, and a multiple value as input. It generates plaintext-ciphertext pairs for each output mask based on the specified parameters.

Args

outmasks : list
A list of tuples of (input, output masks, bias) for which the last round key needs to be found.
cutoff : int, optional
The cutoff value used for the maximum number of encryptions called from oracle in determining the last round key. Defaults to 10000.
multiple : int, optional
The multiple indicating the size of the batch of values to be encrypted at once used for generating encryption pairs. Defaults to 1000.

Returns

list
A list of plaintext-ciphertext pairs for each output mask.
Expand source code
def generate_encryption_pairs(self, outmasks, cutoff=10000, multiple=1000):
    """Generates encryption pairs for a set of output masks.

    Linear cryptanalysis needs plain known-plaintext pairs, so one shared pool of
    plaintexts is built (already-requested plaintexts first, then fresh random ones),
    encrypted via ``self.update_encryptions``, and returned once per mask.

    Args:
        outmasks (list): A list of tuples of (input, output masks, bias) for which
                            the last round key needs to be found. ``bias`` must be non-zero.
        cutoff (int, optional): Upper bound on the number of plaintexts used. Defaults to 10000.
        multiple (int, optional): The multiple indicating the size of the batch of values to be
                            encrypted at once used for generating encryption pairs. Defaults to 1000.

    Returns:
        list: ``len(outmasks)`` references to the same list of ``(plaintext, ciphertext)``
        tuples (an empty list when ``outmasks`` is empty).
    """
    from itertools import islice

    if not outmasks:
        return []
    block_bits = self.num_sbox * self.box_size
    max_threshold = max(100 * int(1 / (bias * bias)) for _, _, bias in outmasks)
    # Never ask for more distinct plaintexts than the block can hold; otherwise the
    # sampling loop below can never finish on small blocks.
    threshold = min(cutoff, max_threshold, 1 << block_bits)
    all_pt = list(islice(self.encryptions, threshold))
    while len(all_pt) < threshold:
        r = random.randint(0, 2**block_bits - 1)
        if r in self.encryptions:
            continue
        self.encryptions[r] = None  # placeholder, filled by update_encryptions
        all_pt.append(r)
    self.update_encryptions(multiple=multiple)
    all_ptct = [(i, self.encryptions[i]) for i in all_pt]
    return [all_ptct] * len(outmasks)

Inherited members

class SPN (sbox, pbox, key, rounds, implementation=0)

Class representing the SPN (Substitution-Permutation Network) encryption algorithm.

Methods

perm(inp) Apply the P-box permutation on the input. inv_perm(inp) Apply the inverse P-box permutation on the input. sub(inp) Apply the S-box substitution on the input. inv_sub(inp) Apply the inverse S-box substitution on the input. int_to_list(inp) Convert a len(pbox)-sized integer to a list of S-box sized integers. list_to_int(lst) Convert a list of S-box sized integers to a len(pbox)-sized integer. expand_key(key, rounds) Derive round keys deterministically from the given key. _enc_last_noperm(pt) Encrypt plaintext using the SPN, where the last round doesn't contain the permute operation. _enc_last_withperm(ct) Encrypt plaintext using the SPN, where the last round contains the permute operation. _dec_last_noperm(ct) Decrypt ciphertext using the SPN, where the last round doesn't contain the permute operation. _dec_last_withperm(ct) Decrypt ciphertext using the SPN, where the last round contains the permute operation.

Initialize the SPN class with the provided parameters.

Parameters

sbox : list of int
List of integers representing the S-box.
pbox : list of int
List of integers representing the P-box.
key : list of int or bytes or bytearray
List of integers, bytes, or bytearray representing the key. LSB block_size bits will be used.
rounds : int
Number of rounds for the SPN.
implementation : int, optional
Implementation option. Default is 0. 0: Last round doesn't contain the permute operation. 1: Last round contains the permute operation.
Expand source code
class SPN:
    """
    Substitution-Permutation Network (SPN) block cipher.

    Methods
    -------
    perm(inp)
        Move every bit of ``inp`` to the position given by the P-box.
    inv_perm(inp)
        Undo ``perm`` using the inverted P-box.
    sub(inp)
        Run every S-box sized chunk of ``inp`` through the S-box.
    inv_sub(inp)
        Undo ``sub`` using the inverted S-box.
    int_to_list(inp)
        Split a len(pbox)-sized integer into S-box sized chunks, most significant first.
    list_to_int(lst)
        Join S-box sized chunks (most significant first) into a len(pbox)-sized integer.
    expand_key(key, rounds)
        Derive ``rounds + 1`` round keys deterministically from ``key``.
    _enc_last_noperm(pt)
        Encrypt; the final round applies no permutation.
    _enc_last_withperm(ct)
        Encrypt; every round, including the final one, permutes.
    _dec_last_noperm(ct)
        Inverse of ``_enc_last_noperm``.
    _dec_last_withperm(ct)
        Inverse of ``_enc_last_withperm``.
    """

    def __init__(self, sbox, pbox, key, rounds, implementation=0):
        """
        Build the cipher and derive its round keys.

        Parameters
        ----------
        sbox : list of int
            The substitution table; its length determines the S-box width in bits.

        pbox : list of int
            The bit permutation; its length is the block size in bits.

        key : list of int or bytes or bytearray
            The master key (an int is also accepted). Only the low block_size bits are kept.

        rounds : int
            Number of rounds; ``rounds + 1`` round keys are derived.

        implementation : int, optional
            0 (default): the last round skips the permutation.
            Anything else: the last round also permutes.
        """
        self.sbox = sbox
        self.pbox = pbox
        self.sinv = [sbox.index(value) for value in range(len(sbox))]
        self.pinv = [pbox.index(position) for position in range(len(pbox))]
        self.block_size = len(pbox)
        self.box_size = int(log2(len(sbox)))
        self.num_sbox = len(pbox) // self.box_size
        self.rounds = rounds
        self.round_keys = self.expand_key(key, rounds)
        if implementation != 0:
            self.encrypt, self.decrypt = self._enc_last_withperm, self._dec_last_withperm
        else:
            self.encrypt, self.decrypt = self._enc_last_noperm, self._dec_last_noperm

    def perm(self, inp: int) -> int:
        """
        Permute the bits of a block according to the P-box.

        Parameters
        ----------
        inp : int
            Block to permute. Bit positions count from the most significant bit.

        Returns
        -------
        int
            Block where the bit at position ``i`` of ``inp`` sits at position ``pbox[i]``.
        """
        last = self.block_size - 1
        result = 0
        for src, dst in enumerate(self.pbox):
            bit = (inp >> (last - src)) & 1
            result |= bit << (last - dst)
        return result

    def inv_perm(self, inp: int) -> int:
        """
        Undo ``perm`` by permuting the bits with the inverted P-box.

        Parameters
        ----------
        inp : int
            Block to un-permute.

        Returns
        -------
        int
            Block where the bit at position ``i`` of ``inp`` sits at position ``pinv[i]``.
        """
        last = self.block_size - 1
        result = 0
        for src, dst in enumerate(self.pinv):
            bit = (inp >> (last - src)) & 1
            result |= bit << (last - dst)
        return result

    def sub(self, inp: int) -> int:
        """
        Substitute every S-box sized chunk of a block through the S-box.

        Parameters
        ----------
        inp : int
            Block to substitute.

        Returns
        -------
        int
            Block whose chunks have each been replaced by their S-box image.
        """
        width = self.box_size
        mask = (1 << width) - 1
        result = 0
        remaining = inp
        for index in range(self.num_sbox):
            result |= self.sbox[remaining & mask] << (index * width)
            remaining >>= width
        return result

    def inv_sub(self, inp: int) -> int:
        """
        Undo ``sub`` by substituting every chunk through the inverted S-box.

        Parameters
        ----------
        inp : int
            Block to inverse-substitute.

        Returns
        -------
        int
            Block whose chunks have each been replaced by their inverse S-box image.
        """
        width = self.box_size
        mask = (1 << width) - 1
        result = 0
        remaining = inp
        for index in range(self.num_sbox):
            result |= self.sinv[remaining & mask] << (index * width)
            remaining >>= width
        return result

    def int_to_list(self, inp):
        """
        Split a len(pbox)-sized integer into S-box sized chunks.

        Parameters
        ----------
        inp : int
            Block to split.

        Returns
        -------
        list of int
            ``num_sbox`` chunks, most significant chunk first.
        """
        width = self.box_size
        low_mask = (1 << width) - 1
        chunks = []
        for position in reversed(range(self.num_sbox)):
            chunks.append((inp >> (position * width)) & low_mask)
        return chunks

    def list_to_int(self, lst):
        """
        Join S-box sized chunks into one len(pbox)-sized integer.

        Parameters
        ----------
        lst : list of int
            Chunks, most significant chunk first.

        Returns
        -------
        int
            The concatenated block.
        """
        result = 0
        for chunk in lst:
            result = (result << self.box_size) | chunk
        return result

    def expand_key(self, key, rounds):
        """
        Derive the round keys from the master key.

        Each new key is the previous one rotated left by ``box_size + 1`` bits and
        then passed through ``sub``.

        Parameters
        ----------
        key : list of int or bytes or bytearray
            Master key (an int is also accepted); truncated to block_size bits.
        rounds : int
            Number of additional keys to derive.

        Returns
        -------
        list of int
            ``rounds + 1`` round keys, the truncated master key first.
        """
        if isinstance(key, (bytes, bytearray)):
            seed = int.from_bytes(key, 'big')
        elif isinstance(key, list):
            seed = self.list_to_int(key)
        else:
            seed = key
        seed &= (1 << self.block_size) - 1
        round_keys = [seed]
        for _ in range(rounds):
            rotated = rotate_left(round_keys[-1], self.box_size + 1, self.block_size)
            round_keys.append(self.sub(rotated))
        return round_keys

    def _enc_last_noperm(self, pt: int) -> int:
        """
        Encrypt a block; the final round applies no permutation.

        Parameters
        ----------
        pt : int
            Plaintext block.

        Returns
        -------
        int
            Ciphertext block.
        """
        keys = self.round_keys
        state = pt ^ keys[0]
        for round_key in keys[1:-1]:
            state = self.perm(self.sub(state)) ^ round_key
        return self.sub(state) ^ keys[-1]

    def _enc_last_withperm(self, ct: int) -> int:
        """
        Encrypt a block where every round, including the final one, permutes.
        Note, the last permutation provides no additional security.

        Parameters
        ----------
        ct : int
            Plaintext block.

        Returns
        -------
        int
            Ciphertext block.
        """
        state = ct
        for round_key in self.round_keys[:-1]:
            state = self.perm(self.sub(state ^ round_key))
        return state ^ self.round_keys[-1]

    def _dec_last_noperm(self, ct: int) -> int:
        """
        Decrypt a block produced by ``_enc_last_noperm``.

        Parameters
        ----------
        ct : int
            Ciphertext block.

        Returns
        -------
        int
            Plaintext block.
        """
        keys = self.round_keys
        state = self.inv_sub(ct ^ keys[-1])
        for round_key in reversed(keys[1:-1]):
            state = self.inv_sub(self.inv_perm(state ^ round_key))
        return state ^ keys[0]

    def _dec_last_withperm(self, ct: int) -> int:
        """
        Decrypt a block produced by ``_enc_last_withperm``.

        Parameters
        ----------
        ct : int
            Ciphertext block.

        Returns
        -------
        int
            Plaintext block.
        """
        state = ct ^ self.round_keys[-1]
        for round_key in reversed(self.round_keys[:-1]):
            state = self.inv_sub(self.inv_perm(state)) ^ round_key
        return state

Subclasses

Methods

def expand_key(self, key, rounds)

Derive round keys deterministically from the given key.

Parameters

key : list of int or bytes or bytearray
A list of integers, bytes, or bytearray representing the key.
rounds : int
The number of rounds for the SPN.

Returns

list of int
A list of integers representing the derived round keys.
Expand source code
def expand_key(self, key, rounds):
    """
    Derive the round keys from the master key.

    Each new key is the previous one rotated left by ``box_size + 1`` bits and
    then passed through ``sub``.

    Parameters
    ----------
    key : list of int or bytes or bytearray
        Master key (an int is also accepted); truncated to block_size bits.
    rounds : int
        Number of additional keys to derive.

    Returns
    -------
    list of int
        ``rounds + 1`` round keys, the truncated master key first.
    """
    if isinstance(key, (bytes, bytearray)):
        seed = int.from_bytes(key, 'big')
    elif isinstance(key, list):
        seed = self.list_to_int(key)
    else:
        seed = key
    seed &= (1 << self.block_size) - 1
    round_keys = [seed]
    for _ in range(rounds):
        rotated = rotate_left(round_keys[-1], self.box_size + 1, self.block_size)
        round_keys.append(self.sub(rotated))
    return round_keys
def int_to_list(self, inp)

Convert a len(pbox)-sized integer to a list of S-box sized integers.

Parameters

inp : int
An integer representing a len(pbox)-sized input.

Returns

list of int
A list of integers, each representing an S-box sized input.
Expand source code
def int_to_list(self, inp):
    """
    Split a len(pbox)-sized integer into S-box sized chunks.

    Parameters
    ----------
    inp : int
        Block to split.

    Returns
    -------
    list of int
        ``num_sbox`` chunks, most significant chunk first.
    """
    width = self.box_size
    low_mask = (1 << width) - 1
    chunks = []
    for position in reversed(range(self.num_sbox)):
        chunks.append((inp >> (position * width)) & low_mask)
    return chunks
def inv_perm(self, inp: int) ‑> int

Apply the inverse P-box permutation on the input.

Parameters

inp : int
The input value to apply the inverse P-box permutation on.

Returns

int
The permuted value after applying the inverse P-box.
Expand source code
def inv_perm(self, inp: int) -> int:
    """
    Apply the inverse P-box permutation on the input.

    Bits are numbered from the most significant end (index 0 is bit
    ``block_size - 1``). The bit at index ``src`` is moved to index
    ``self.pinv[src]``.

    Parameters
    ----------
    inp : int
        The input value to apply the inverse P-box permutation on.

    Returns
    -------
    int
        The permuted value after applying the inverse P-box.
    """
    top = self.block_size - 1
    result = 0
    for src, dst in enumerate(self.pinv):
        bit = (inp >> (top - src)) & 1
        result |= bit << (top - dst)
    return result
def inv_sub(self, inp: int) ‑> int

Apply the inverse S-box substitution on the input.

Parameters

inp : int
The input value to apply the inverse S-box substitution on.

Returns

int
The substituted value after applying the inverse S-box.
Expand source code
def inv_sub(self, inp: int) -> int:
    """
    Apply the inverse S-box substitution on the input.

    Each of the ``num_sbox`` fields of ``box_size`` bits (field 0 is the
    least significant) is replaced by its ``self.sinv`` entry and written
    back at the same position. Bits above those fields are dropped.

    Parameters
    ----------
    inp : int
        The input value to apply the inverse S-box substitution on.

    Returns
    -------
    int
        The substituted value after applying the inverse S-box.
    """
    width = self.box_size
    field_mask = (1 << width) - 1
    result = 0
    for idx in range(self.num_sbox):
        offset = idx * width
        result |= self.sinv[(inp >> offset) & field_mask] << offset
    return result
def list_to_int(self, lst)

Convert a list of S-box-sized integers to a len(pbox)-sized integer. The first list element becomes the most significant S-box.

Parameters

lst : list of int
A list of integers, each representing an S-box sized input.

Returns

int
An integer representing the combined input as a len(pbox)-sized integer.
Expand source code
def list_to_int(self, lst):
    """
    Convert a list of S-box sized integers to a len(pbox)-sized integer.

    The last element lands in the least significant ``box_size`` bits.
    Each earlier element is placed one field higher, and the fields are
    OR-ed together.

    Parameters
    ----------
    lst : list of int
        A list of integers, each representing an S-box sized input.

    Returns
    -------
    int
        An integer representing the combined input as a len(pbox)-sized integer.
    """
    packed = 0
    offset = 0
    for field in lst[::-1]:
        packed |= field << offset
        offset += self.box_size
    return packed
def perm(self, inp: int) ‑> int

Apply the P-box permutation on the input.

Parameters

inp : int
The input value to apply the P-box permutation on.

Returns

int
The permuted value after applying the P-box.
Expand source code
def perm(self, inp: int) -> int:
    """
    Apply the P-box permutation on the input.

    Bits are numbered from the most significant end (index 0 is bit
    ``block_size - 1``). The bit at index ``src`` is moved to index
    ``self.pbox[src]``.

    Parameters
    ----------
    inp : int
        The input value to apply the P-box permutation on.

    Returns
    -------
    int
        The permuted value after applying the P-box.
    """
    top = self.block_size - 1
    result = 0
    for src, dst in enumerate(self.pbox):
        bit = (inp >> (top - src)) & 1
        result |= bit << (top - dst)
    return result
def sub(self, inp: int) ‑> int

Apply the S-box substitution on the input.

Parameters

inp : int
The input value to apply the S-box substitution on.

Returns

int
The substituted value after applying the S-box.
Expand source code
def sub(self, inp: int) -> int:
    """
    Apply the S-box substitution on the input.

    Each of the ``num_sbox`` fields of ``box_size`` bits (field 0 is the
    least significant) is replaced by its ``self.sbox`` entry and written
    back at the same position. Bits above those fields are dropped.

    Parameters
    ----------
    inp : int
        The input value to apply the S-box substitution on.

    Returns
    -------
    int
        The substituted value after applying the S-box.
    """
    width = self.box_size
    field_mask = (1 << width) - 1
    result = 0
    for idx in range(self.num_sbox):
        offset = width * idx
        nibble = (inp >> offset) & field_mask
        result |= self.sbox[nibble] << offset
    return result