Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions claasp/cipher_modules/models/cp/cp_build_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from dataclasses import dataclass, field


@dataclass(frozen=True)
class CpBuildContext:
"""Read-only configuration snapshot taken from ``MznModel`` at build time.

Generators receive this instead of the full model so they never
depend on mutable solver state.
"""

cipher: object = None
word_size: int = 0
data_type: str = "bool"
true_value: str = "true"
false_value: str = "false"
input_postfix: str = "x"
output_postfix: str = "y"
sat_or_milp: str = "sat"
float_and_lat_values: tuple = ()
bit_bindings_for_intermediate_output: dict = field(default_factory=dict)

@classmethod
def from_model(cls, model):
return cls(
cipher=model._cipher,
word_size=getattr(model, 'word_size', 0),
data_type=model.data_type,
true_value=model.true_value,
false_value=model.false_value,
input_postfix=getattr(model, 'input_postfix', 'x'),
output_postfix=getattr(model, 'output_postfix', 'y'),
sat_or_milp=getattr(model, 'sat_or_milp', 'sat'),
float_and_lat_values=tuple(getattr(model, '_float_and_lat_values', [])),
bit_bindings_for_intermediate_output=getattr(model, 'bit_bindings_for_intermediate_output', {}),
)
40 changes: 40 additions & 0 deletions claasp/cipher_modules/models/cp/cp_build_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from dataclasses import dataclass, field


@dataclass
class CpBuildState:
"""Mutable accumulator for state that CP constraint generators read and write.

Passed through the builder loop so generators never touch the model directly.
"""

next_probability_index: int = 0
shift_declaration_cache: list = field(default_factory=list)
component_probability_map: dict = field(default_factory=dict)
sbox_table_cache: list = field(default_factory=list)
intermediate_constraints_array: list = field(default_factory=list)
mzn_output_directives: list = field(default_factory=list)

@classmethod
def from_model(cls, model):
return cls(
next_probability_index=model.component_probability_index,
shift_declaration_cache=list(model.modadd_two_term_shift_cache),
component_probability_map=dict(model.component_and_probability),
sbox_table_cache=list(model.sbox_table_cache),
intermediate_constraints_array=list(getattr(model, 'intermediate_constraints_array', [])),
mzn_output_directives=list(getattr(model, 'mzn_output_directives', [])),
)

def apply_to_model(self, model):
model.component_probability_index = self.next_probability_index
model.modadd_two_term_shift_cache = self.shift_declaration_cache
model.component_and_probability = self.component_probability_map
model.sbox_table_cache = self.sbox_table_cache
model.intermediate_constraints_array = self.intermediate_constraints_array
model.mzn_output_directives = self.mzn_output_directives

def allocate_probability_index(self):
idx = self.next_probability_index
self.next_probability_index += 1
return idx
15 changes: 15 additions & 0 deletions claasp/cipher_modules/models/cp/cp_component_build_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses import dataclass, field


@dataclass
class CpComponentBuildResult:
"""Value returned by decoupled CP constraint generators.

Bundles the declarations, constraints and optional metadata produced
by a single component generator call so the caller never needs to
inspect positional-tuple semantics.
"""

declarations: list = field(default_factory=list)
constraints: list = field(default_factory=list)
metadata: dict = field(default_factory=dict)
35 changes: 16 additions & 19 deletions claasp/cipher_modules/models/cp/mzn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
XOR_DIFFERENTIAL_LINEAR_OPTIMAL_SOLUTION,
)

from claasp.cipher_modules.models.cp.cp_build_context import CpBuildContext
from claasp.cipher_modules.models.cp.cp_build_state import CpBuildState
from claasp.cipher_modules.models.utils import write_model_to_file, convert_solver_solution_to_dictionary

SOLVE_SATISFY = "solve satisfy;"
Expand Down Expand Up @@ -85,17 +87,17 @@ def __init__(self, cipher, sat_or_milp='sat'):
def initialise_model(self):
self._variables_list = []
self._model_constraints = []
self.c = 0
self.component_probability_index = 0
if self._cipher.is_spn():
for component in self._cipher.get_all_components():
if SBOX in component.type:
self.word_size = int(component.output_bit_size)
break
self._float_and_lat_values = []
self._probability = False
self.sbox_mant = []
self.mix_column_mant = []
self.modadd_twoterms_mant = []
self.sbox_table_cache = []
self.mix_column_declaration_cache = []
self.modadd_two_term_shift_cache = []
self.input_sbox = []
self.table_of_solutions_length = 0
self.list_of_xor_components = []
Expand Down Expand Up @@ -181,7 +183,9 @@ def add_solution_to_components_values_internal(
components_values[f"solution{solution_number}"][f"{component}"] = component_solution

def build_generic_cp_model_from_dictionary(self, component_and_model_types, fixed_variables=None):
variables = []
context = CpBuildContext.from_model(self)
state = CpBuildState.from_model(self)

self._variables_list = []
self._model_constraints = []
component_types = [CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION]
Expand All @@ -196,21 +200,11 @@ def build_generic_cp_model_from_dictionary(self, component_and_model_types, fixe
print(f'{component.id} not yet implemented')
else:
cp_generic_propagation_constraints = getattr(component, model_type)
try:
result = cp_generic_propagation_constraints()
except TypeError:
result = cp_generic_propagation_constraints(self)

if len(result) == 2:
variables, constraints = result
metadata = {}
elif len(result) == 3:
variables, constraints, metadata = result
else:
raise ValueError("Unexpected return value from component generator")
result = cp_generic_propagation_constraints(context, state)

self._model_constraints.extend(constraints)
self._variables_list.extend(variables)
self._variables_list.extend(result.declarations)
self._model_constraints.extend(result.constraints)
metadata = result.metadata

if metadata:
probability_var = metadata.get("probability_var")
Expand All @@ -222,6 +216,9 @@ def build_generic_cp_model_from_dictionary(self, component_and_model_types, fixe
if round_index < len(self.probability_modadd_vars_per_round):
self.probability_modadd_vars_per_round[round_index].append(probability_var)

# Apply accumulated state back to model
state.apply_to_model(self)

def build_mix_column_truncated_table(self, component):
"""
Return a model that generates the list of possible input/output couples for the given mix column.
Expand Down
17 changes: 8 additions & 9 deletions claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def build_cipher_model(self, fixed_variables=[], second=False):
"""
self.initialise_model()
self._model_prefix.extend(self.input_constraints())
self.sbox_mant = []
variables = []
self.sbox_table_cache = []
self._variables_list = []
constraints = self.fix_variables_value_constraints(fixed_variables)
component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION)
Expand All @@ -70,14 +69,14 @@ def build_cipher_model(self, fixed_variables=[], second=False):
WORD_OPERATION == component.type and operation not in operation_types
):
print(f"{component.id} not yet implemented")
continue
if component.type != SBOX:
result = component.cp_constraints()
else:
if component.type != SBOX:
variables, constraints = component.cp_constraints()
else:
variables, constraints = component.cp_constraints(self.sbox_mant)
result = component.cp_constraints(self.sbox_table_cache)

self._model_constraints.extend(constraints)
self._variables_list.extend(variables)
self._model_constraints.extend(result.constraints)
self._variables_list.extend(result.declarations)

self._model_constraints.extend(self.final_constraints())

Expand Down Expand Up @@ -138,7 +137,7 @@ def input_constraints(self):
...
'array[0..31] of var 0..1: cipher_output_3_12;']
"""
self.sbox_mant = []
self.sbox_table_cache = []
cp_declarations = [
f"array[0..{bit_size - 1}] of var 0..1: {input_};"
for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# ****************************************************************************


from claasp.cipher_modules.models.cp.cp_build_context import CpBuildContext
from claasp.cipher_modules.models.cp.cp_build_state import CpBuildState
from claasp.cipher_modules.models.cp.mzn_model import MznModel
from claasp.name_mappings import CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, WORD_OPERATION

Expand Down Expand Up @@ -47,9 +48,10 @@ def build_cipher_model(self, fixed_variables=[]):
...
"""
self._variables_list = []
variables = []
constraints = self.fix_variables_value_constraints_for_ARX(fixed_variables)
self._model_constraints = constraints
context = CpBuildContext.from_model(self)
state = CpBuildState.from_model(self)
component_types = [CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, WORD_OPERATION]
operation_types = ["ROTATE", "SHIFT", "XOR"]

Expand All @@ -59,8 +61,10 @@ def build_cipher_model(self, fixed_variables=[]):
WORD_OPERATION == component.type and operation not in operation_types
):
print(f"{component.id} not yet implemented")
else:
variables, constraints = component.minizinc_constraints(self)
continue
result = component.minizinc_constraints(context, state)

self._model_constraints.extend(result.constraints)
self._variables_list.extend(result.declarations)

self._model_constraints.extend(constraints)
self._variables_list.extend(variables)
state.apply_to_model(self)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from minizinc import Status

from claasp.cipher_modules.models.cp.cp_build_context import CpBuildContext
from claasp.cipher_modules.models.cp.cp_build_state import CpBuildState
from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY
from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT
from claasp.cipher_modules.models.utils import (
Expand Down Expand Up @@ -492,13 +494,14 @@ def output_inverse_constraints(self, component):
def propagate_deterministically(self, component, wordwise=False, inverse=False):
if not wordwise:
if component.type == SBOX:
variables, constraints, sbox_mant = (
component.cp_deterministic_truncated_xor_differential_trail_constraints(self.sbox_mant, inverse)
)
self.sbox_mant = sbox_mant
result = component.cp_deterministic_truncated_xor_differential_trail_constraints(self.sbox_table_cache, inverse)
self.sbox_table_cache = result.metadata
else:
variables, constraints = component.cp_deterministic_truncated_xor_differential_trail_constraints()
result = component.cp_deterministic_truncated_xor_differential_trail_constraints()
else:
variables, constraints = component.cp_wordwise_deterministic_truncated_xor_differential_constraints(self)
context = CpBuildContext.from_model(self)
state = CpBuildState.from_model(self)
result = component.cp_wordwise_deterministic_truncated_xor_differential_constraints(context, state)
state.apply_to_model(self)

return variables, constraints
return result.declarations, result.constraints
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def build_deterministic_truncated_xor_differential_trail_model(self, fixed_varia
sage: minizinc.build_deterministic_truncated_xor_differential_trail_model()
...
"""
variables = []
constraints = self.fix_variables_value_constraints_for_ARX(fixed_variables)
self._variables_list = []
self._model_constraints = constraints
Expand All @@ -57,11 +56,12 @@ def build_deterministic_truncated_xor_differential_trail_model(self, fixed_varia
operation_types = ["ROTATE", "SHIFT"]

if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types):
variables, constraints = component.minizinc_deterministic_truncated_xor_differential_trail_constraints(
result = component.minizinc_deterministic_truncated_xor_differential_trail_constraints(
self
)
else:
print(f"{component.id} not yet implemented")
continue

self._variables_list.extend(variables)
self._model_constraints.extend(constraints)
self._variables_list.extend(result.declarations)
self._model_constraints.extend(result.constraints)
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from sage.combinat.permutation import Permutation
from sage.crypto.sbox import SBox

from claasp.cipher_modules.models.cp.cp_build_context import CpBuildContext
from claasp.cipher_modules.models.cp.cp_build_state import CpBuildState
from claasp.cipher_modules.models.cp.mzn_model import SOLVE_SATISFY, MznModel
from claasp.cipher_modules.models.cp.mzn_models.mzn_impossible_xor_differential_model import (
MznImpossibleXorDifferentialModel,
Expand Down Expand Up @@ -721,23 +723,26 @@ def propagate_deterministically(
if not wordwise:
if component.type == SBOX:
if key_schedule and probabilistic:
variables, constraints = component.cp_xor_differential_propagation_constraints(self, inverse)
state = CpBuildState.from_model(self)
result = component.cp_xor_differential_propagation_constraints(None, state, inverse)
state.apply_to_model(self)
else:
variables, constraints, sbox_mant = (
component.cp_hybrid_deterministic_truncated_xor_differential_constraints(
self.sbox_mant, inverse, self.sboxes_component_number_list
)
result = component.cp_hybrid_deterministic_truncated_xor_differential_constraints(
self.sbox_table_cache, inverse, self.sboxes_component_number_list
)
self.sbox_mant = sbox_mant
self.sbox_table_cache = result.metadata
self.sbox_size = component.output_bit_size
elif component.description[0] == "XOR":
variables, constraints = component.cp_hybrid_deterministic_truncated_xor_differential_constraints()
result = component.cp_hybrid_deterministic_truncated_xor_differential_constraints()
else:
variables, constraints = component.cp_deterministic_truncated_xor_differential_trail_constraints()
result = component.cp_deterministic_truncated_xor_differential_trail_constraints()
else:
variables, constraints = component.cp_wordwise_deterministic_truncated_xor_differential_constraints(self)
context = CpBuildContext.from_model(self)
state = CpBuildState.from_model(self)
result = component.cp_wordwise_deterministic_truncated_xor_differential_constraints(context, state)
state.apply_to_model(self)

return variables, constraints
return result.declarations, result.constraints

def format_component_value(self, component_id, string):
if f"{component_id}_i" in string:
Expand Down
Loading
Loading