"""Generator for constraints."""
import builtins
import logging
import pathlib
import random
from typing import Iterator, List, Optional, Tuple, Union
from jinja2 import Environment, FileSystemLoader, Template
from z3 import BitVec, Solver, sat
from .constraint import (
CompareConstraint,
Constraint,
OperationConstraint,
)
from .operator import OPERATORS, REVERSIBLE_OPERATORS
# Directory of this module and the Jinja "templates" directory next to it.
HERE = pathlib.Path(__file__).parent.resolve()
TEMPLATES = HERE / "templates"
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class NoMoreConstraintError(RecursionError):
    """Signal that no further constraint can be generated.

    Subclasses ``RecursionError`` so handlers written for that exception
    also catch this one.
    """
[docs]class Z3Armor:
[docs] def __init__(
self,
secret: bytes,
random_state: Optional[int] = None,
) -> None:
"""Instantiated Z3Armor."""
self.indexes = {i: 0 for i in range(len(secret))}
self.constraints: List[Constraint] = []
if random_state is not None:
self.rand = random.Random(random_state) # noqa: S311
else:
self.rand = random.SystemRandom()
self.secret = secret
[docs] def __str__(self) -> str:
"""Represent the ConstraintGenerator."""
return (
f"<ConstraintGenerator constraints={len(self.constraints)} "
f"secret={self.secret!r}>"
)
def solver(self) -> Tuple[Solver, List[BitVec]]:
    """Build a fresh solver loaded with every current constraint.

    Returns:
        The solver and the list of 8-bit vectors named ``p[0]``,
        ``p[1]``, ... with one entry per byte of the secret.
    """
    byte_count = len(self.secret)
    terms = [BitVec(f"p[{index}]", 8) for index in range(byte_count)]
    engine = Solver()
    for constraint in self.constraints:
        # Each constraint turns itself into an assertion over the terms.
        engine.add(constraint.apply(terms))
    return engine, terms
def verify(self, secret: bytes) -> bool:
    """Tell whether ``secret`` passes every stored constraint.

    Checking stops at the first constraint whose ``check`` is falsy.
    """
    for constraint in self.constraints:
        if not constraint.check(secret):
            return False
    return True
def complete(self) -> bool:
    """Tell whether the constraints admit the secret and nothing else.

    Returns:
        ``True`` only when the solver is satisfiable, its model assigns
        every byte, that first model equals ``self.secret`` and
        ``solutions()`` does not produce a second result.

    Raises:
        TypeError: If ``solutions()`` yields nothing.
    """
    solver, terms = self.solver()

    # The constraints must be satisfiable at all.
    verdict = solver.check()
    if verdict != sat:
        logger.debug("No complete because model solver is %s", verdict)
        return False

    # Every byte of the secret must be assigned by the model.
    model = solver.model()
    values = [model[term] for term in terms]
    for value in values:
        if value is None:
            logger.debug("No complete because some char are missing")
            return False

    # The solver's first model has to be the secret itself.
    first_model = bytes(value.as_long() for value in values)
    if first_model != self.secret:
        logger.debug("No complete because first guess is wrong")
        return False

    # Enumerate solutions and stop as soon as a second one shows up.
    found: Optional[bytes] = None
    for candidate in self.solutions():
        if found:
            logger.debug("No complete because other guess exist")
            return False
        found = candidate
    if found is None:
        raise TypeError

    logger.debug("Complete")
    return True
def solutions(self) -> Iterator[bytes]:
    """Enumerate every password accepted by the current constraints.

    Solutions are found by repeatedly blocking one term of the last
    model while pinning the terms before it, as described in the
    references below.

    Yields:
        Each distinct solution as a byte string with one byte per term,
        i.e. always the full length of the secret.
    """
    # ref: https://stackoverflow.com/questions/11867611/z3py-checking-all-solutions-for-equation
    # ref: https://theory.stanford.edu/%7Enikolaj/programmingz3.html#sec-blocking-evaluations
    solver, terms = self.solver()
    previous = set()

    def _all_smt_rec(remaining: List[BitVec]) -> Iterator[bytes]:
        if solver.check() != sat:
            return
        model = solver.model()

        def _value(term):
            # Model completion gives uninterpreted terms a concrete value.
            # Without it ``eval`` returns the term itself and the blocking
            # clause below degenerates to ``term != term``.
            return model.eval(term, model_completion=True)

        # Build the password from *all* terms, not only the ones still
        # open in this branch; otherwise nested calls yield a suffix.
        password = bytes(_value(term).as_long() for term in terms)
        if password not in previous:
            previous.add(password)
            yield password

        for i, term in enumerate(remaining):
            solver.push()
            try:
                # Force a different value for this term ...
                solver.add(term != _value(term))
                # ... while keeping the earlier terms as in this model.
                for fixed in remaining[:i]:
                    solver.add(fixed == _value(fixed))
                yield from _all_smt_rec(remaining[i:])
            finally:
                solver.pop()

    yield from _all_smt_rec(terms)
def weights(self) -> List[float]:
    """Return each index's share of the total usage count.

    Returns:
        One float per entry of ``self.indexes``, in the same order.
        When no index has been used yet, every index gets an equal
        share. An empty list is returned when there are no indexes.
    """
    counts = list(self.indexes.values())
    if not counts:
        # No indexes at all (empty secret): avoid dividing by zero.
        return []
    total = sum(counts)
    if total == 0:
        return [1 / len(counts)] * len(counts)
    return [count / total for count in counts]
def generate(self, rand: Optional[random.Random] = None) -> Constraint:
    """Create one new constraint, store it and return it.

    A coin flip picks between a ``CompareConstraint`` and an
    ``OperationConstraint`` over two sampled indexes. Rejected
    candidates (a compare constant of 0, or a constraint already
    stored) are retried by calling this method again with the same
    random generator.

    Args:
        rand: Random generator to draw from. When ``None``, a temporary
            one is seeded from ``self.rand``.

    Returns:
        The constraint that was appended to ``self.constraints``.

    Raises:
        NoMoreConstraintError: If the retries run into a
            ``RecursionError``.
    """
    if rand is None:
        # Temporary generator, handed down unchanged to every retry.
        rand = random.Random(self.rand.random())  # noqa: S311

    error_message = "No more constraint can be generated."
    want_compare = rand.randint(0, 1) == 0
    x, y = self.weighted_sampling(rand, 2)

    const: Constraint
    if want_compare:
        rev_op = rand.choice(list(REVERSIBLE_OPERATORS.values()))
        n = rev_op.reverse(self.secret[y], self.secret[x]) % 256
        if n == 0:
            # A zero constant is rejected (original note: "avoid linkage
            # of the same letter"), so draw again.
            try:
                return self.generate(rand=rand)
            except RecursionError:
                raise NoMoreConstraintError(error_message) from None
        const = CompareConstraint(x, y, rev_op, n)
    else:
        op = rand.choice(list(OPERATORS.values()))
        n = op(self.secret[x], self.secret[y]) % 256
        const = OperationConstraint(x, y, op, n)

    if const in self.constraints:
        # Already stored: draw again instead of adding a duplicate.
        try:
            return self.generate(rand=rand)
        except RecursionError:
            raise NoMoreConstraintError(error_message) from None

    # Record the usage of both indexes, then store the constraint.
    for index in (x, y):
        self.indexes[index] += 1
    self.constraints.append(const)
    logger.info("Generate new constraint %s", const)
    return const
def fit(self) -> None:
    """Generate constraints until the model is complete, then reduce.

    Generation also stops when ``generate`` reports that no more
    constraints are available.

    Raises:
        RuntimeError: If the model is not complete after ``reduce``.
    """
    while True:
        if self.complete():
            break
        try:
            self.generate()
        except NoMoreConstraintError:  # noqa: PERF203
            logger.info("No more constraints")
            break

    self.reduce()
    if self.complete():
        return
    error_message = "Model is not sat after reduce"
    raise RuntimeError(error_message)
def weighted_sampling(
    self,
    rand: random.Random,
    k: int,
) -> List[int]:
    """Draw ``k`` distinct indexes, favouring the least-used ones.

    Args:
        rand: Random generator whose ``choices`` method does the draw.
        k: Number of indexes to pick.

    Returns:
        The picked indexes in draw order. ``self.indexes`` is untouched.

    Raises:
        IndexError: If ``k`` exceeds the number of known indexes.
    """
    if len(self.indexes) < k:
        error_message = "Not enough indexes."
        raise IndexError(error_message)

    remaining = dict(self.indexes)
    picked: List[int] = []
    while len(picked) < k:
        counts = remaining.values()
        ceiling = max(counts)
        if ceiling == min(counts):
            # All counts equal: lift the ceiling so every weight is 1.
            ceiling += 1
        # The weight is the distance to the ceiling, so rarely used
        # indexes weigh more and the most used ones weigh 0.
        weights = [ceiling - count for count in counts]
        (index,) = rand.choices(list(remaining), k=1, weights=weights)
        picked.append(index)
        # Drop the pick so it cannot be drawn twice.
        del remaining[index]
    return picked
def reduce(self) -> None:
    """Drop every constraint the model does not need to stay complete.

    Each constraint is taken out in turn. It stays out when
    ``complete()`` still holds without it, and is put back at its
    original position otherwise. Passes are repeated until a full pass
    removes nothing, so on return no single constraint can be removed
    while keeping the model complete. The relative order of the kept
    constraints is preserved.
    """
    logger.info("Start reduction")
    start_size = len(self.constraints)

    removed = True
    while removed:
        removed = False
        # Walk by position rather than iterating the list directly:
        # removing and re-adding items while iterating the same list
        # makes the iterator skip constraints.
        position = 0
        while position < len(self.constraints):
            candidate = self.constraints.pop(position)
            if self.complete():
                # Not needed; the next constraint slid into ``position``.
                removed = True
                continue
            self.constraints.insert(position, candidate)
            position += 1

    logger.info(
        "Reduction result: %s to %s",
        start_size,
        len(self.constraints),
    )