Appendix I: The Discrete Math Toolkit (Full Code)

Throughout this book you have been building one thing, a little at a time: a small Python package called dmtoolkit. Every chapter's Project Checkpoint added a function or two — never more than thirty lines, always a direct transcription of a definition or theorem you had just learned. The idea was that mathematics you can run is mathematics you understand. By the end you have not just read about counting, modular arithmetic, and graph algorithms; you have implemented them from scratch.

This appendix assembles all of those increments into one coherent reference: the complete package, module by module, reconciled into a single consistent version. Where two chapters needed the same function (for instance, Chapters 22 and 23 both needed gcd, and Chapters 24 and 25 both built RSA routines), each checkpoint repeated the code so its chapter could stand alone. Here those duplicates are merged into one canonical definition, the signatures are exactly the frozen ones from the toolkit API, and the graph functions all share a single Graph class instead of the per-chapter stand-ins.

🔗 Connection: This is the reference version. Each chapter's code/project-checkpoint.py is the teaching version, with extra comments and a self-contained __main__ demo. When the two differ, this appendix wins — it is what the assembled dmtoolkit/ directory should contain.

A few conventions hold across every module:

  • The code is never executed in this book. Every # Expected output: comment was derived by hand. That is deliberate: theme four of this book is that computation and proof are complementary. Running a function on one input tells you what it did once; a proof tells you what it does always. Use both.
  • Clarity over cleverness. Each function stays under thirty lines and prefers a readable loop to a dense one-liner. Where a real library would reach for networkx, sympy, or numpy, we implement the idea from scratch first, because the point is to see the mathematics, not to hide it.
  • Standard library only. The package imports nothing beyond itertools, collections, fractions, heapq, math, and random. It runs on any Python 3.10 or later with no installation.

The suggested package layout:

dmtoolkit/
├── __init__.py
├── logic.py          # Chapters 1–3
├── sets.py           # Chapter 8
├── relations.py      # Chapters 12–13
├── combinatorics.py  # Chapters 15–17
├── recurrences.py    # Chapters 18–19
├── probability.py    # Chapters 20–21
├── numbertheory.py   # Chapters 22–23
├── crypto.py         # Chapters 24–25
├── graphs.py         # Chapters 27–34
└── coding.py         # Chapters 26, 38

logic.py — propositions, quantifiers, and valid arguments

Built in Chapters 1–3. A compound proposition is just a Python function that returns a bool, and "every truth assignment" is itertools.product over [True, False]. From that single idea the module gives you truth tables, a tautology checker, an equivalence checker, finite-domain quantifiers, a counterexample finder, and a decision procedure for argument validity. Chapter 1 contributed truth_table, is_tautology, and equivalent; Chapter 2 added the quantifiers and counterexample (making the negation law $\neg(\forall x\, P(x)) \equiv \exists x\, \neg P(x)$ executable); Chapter 3 added is_valid, which decides whether an argument is valid by the textbook definition — no assignment makes every premise true while the conclusion is false.

"""dmtoolkit/logic.py -- propositional and predicate logic (Chapters 1-3).

A compound proposition is a function of named boolean variables; "every
assignment" is itertools.product over [True, False].
"""
from itertools import product


def truth_table(fn, names):
    """Truth table of fn as a list of (assignment_tuple, output) rows."""
    rows = []
    for assignment in product([True, False], repeat=len(names)):
        rows.append((assignment, fn(*assignment)))
    return rows


def is_tautology(fn, n):
    """True iff the n-ary boolean function fn is true on every assignment."""
    return all(fn(*a) for a in product([True, False], repeat=n))


def equivalent(f, g, n):
    """True iff n-ary boolean functions f and g agree on every assignment."""
    return all(f(*a) == g(*a) for a in product([True, False], repeat=n))


def for_all(predicate, domain):
    """Evaluate (for all x in domain) predicate(x). Vacuously True if empty."""
    return all(predicate(x) for x in domain)


def there_exists(predicate, domain):
    """Evaluate (exists x in domain) predicate(x). False if domain is empty."""
    return any(predicate(x) for x in domain)


def counterexample(predicate, domain):
    """First x in domain with predicate(x) False (a witness to the negation),
    or None if predicate holds for all of domain."""
    for x in domain:
        if not predicate(x):
            return x
    return None


def is_valid(premises, conclusion, names):
    """True iff the argument (premises therefore conclusion) is valid: no
    assignment makes every premise True while the conclusion is False.

    premises  : list of callables, each taking len(names) booleans -> bool
    conclusion: one such callable
    names     : variable-name strings (fixes the arity and order)
    """
    n = len(names)
    for vals in product([True, False], repeat=n):
        if all(p(*vals) for p in premises) and not conclusion(*vals):
            return False                       # a counterexample assignment
    return True


# Expected behavior (hand-derived):
#   is_tautology(lambda p: p or (not p), 1)                       -> True
#   equivalent(lambda p, q: not (p and q),
#              lambda p, q: (not p) or (not q), 2)                -> True   (De Morgan)
#   counterexample(lambda p: p % 2 == 1, [2, 3, 5, 7])           -> 2      ("every prime is odd"?)
#   is_valid([lambda p, q: (not p) or q, lambda p, q: p],
#            lambda p, q: q, ["p", "q"])                          -> True   (modus ponens)

⚠️ Common Pitfall: truth_table, is_tautology, equivalent, and is_valid take an arity n or a list of variable names, not a domain. They quantify over the $2^n$ boolean assignments. The quantifiers for_all and there_exists, by contrast, range over an explicit finite domain you supply. Mixing the two up is the most common early mistake: propositional truth lives over {True, False}; predicate truth lives over whatever universe you hand it.
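The negation law from Chapter 2 can be checked exhaustively on a small finite domain. The sketch below is standalone: it redefines for_all and there_exists exactly as in the listing, and negation_law_holds is a hypothetical helper, not part of dmtoolkit. It relies on one observation that joins the two halves of the module: a predicate on an n-element domain is just one truth value per element, so itertools.product can enumerate all $2^n$ predicates.

```python
from itertools import product


def for_all(predicate, domain):
    return all(predicate(x) for x in domain)


def there_exists(predicate, domain):
    return any(predicate(x) for x in domain)


def negation_law_holds(domain):
    """Hypothetical helper: check not(forall x P(x)) == exists x not P(x)
    for EVERY predicate P on a finite domain."""
    domain = list(domain)
    # A predicate on a finite domain is a lookup table of truth values,
    # so product([True, False], ...) enumerates all 2 ** len(domain) of them.
    for values in product([True, False], repeat=len(domain)):
        P = dict(zip(domain, values)).__getitem__
        lhs = not for_all(P, domain)
        rhs = there_exists(lambda x: not P(x), domain)
        if lhs != rhs:
            return False
    return True


print(negation_law_holds(["a", "b", "c"]))   # True: all 8 predicates checked
```

As everywhere in this appendix, a passing check covers only the domains you try; the Chapter 2 proof covers every domain.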


sets.py — set operations from their definitions

Built in Chapter 8. The first four functions are one-line transcriptions of set-builder definitions: union is $\{x \mid x \in A \lor x \in B\}$, intersection is $\{x \mid x \in A \land x \in B\}$, and so on. power_set is the one with real algorithmic content: it doubles the collection of subsets each time it absorbs a new element, which is exactly why $\lvert \mathcal{P}(S)\rvert = 2^{\lvert S\rvert}$. These signatures are kept stable because later modules compose them: relations are sets of ordered pairs, and cartesian is how you build the universe a relation lives in.

"""dmtoolkit/sets.py -- finite set operations from scratch (Chapter 8).

Each function transcribes a set-builder definition over explicit collections.
"""


def union(a, b):
    """{ x | x in a or x in b }."""
    return {x for x in a} | {x for x in b}


def intersection(a, b):
    """{ x | x in a and x in b }."""
    return {x for x in a if x in b}


def difference(a, b):
    """{ x | x in a and x not in b }."""
    return {x for x in a if x not in b}


def cartesian(a, b):
    """{ (x, y) | x in a and y in b }."""
    return {(x, y) for x in a for y in b}


def power_set(s):
    """{ A | A subseteq s }, as a set of frozensets. Size is 2 ** |s|.
    Doubles the collection per element: each subset, with and without x."""
    subsets = {frozenset()}                       # start with the empty set
    for x in s:                                   # peel off one element at a time
        subsets |= {sub | {x} for sub in subsets}
    return subsets


# Expected behavior (hand-derived), with A = {1, 2, 3}, B = {3, 4}:
#   sorted(union(A, B))        -> [1, 2, 3, 4]
#   sorted(intersection(A, B)) -> [3]
#   sorted(difference(A, B))   -> [1, 2]
#   len(power_set({1, 2, 3}))  -> 8         ( = 2 ** 3 )
#   len(cartesian({1, 2}, {"x", "y", "z"})) -> 6   ( = 2 * 3 )

💡 Intuition: power_set returns frozensets, not sets, because the result is a set of sets — and Python sets can only contain hashable (immutable) elements. A frozenset is just a set you cannot change, which makes it hashable. This is the same reason cartesian returns tuples, not lists.
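A quick standalone check of both claims, with power_set copied from the listing: the sizes double with each added element, and a set of plain sets is rejected while a set of frozensets is accepted. The wording of Python's TypeError message can vary between versions, so the sketch only records whether the exception occurred.

```python
def power_set(s):
    """Same doubling construction as dmtoolkit.sets.power_set."""
    subsets = {frozenset()}
    for x in s:
        subsets |= {sub | {x} for sub in subsets}   # frozenset | set -> frozenset
    return subsets


# |P(S)| = 2 ** |S|: one subset for the empty set, then doubling per element
sizes = [len(power_set(range(n))) for n in range(5)]
print(sizes)                          # [1, 2, 4, 8, 16]

# Plain sets are unhashable, so they cannot be elements of a set...
try:
    {{1, 2}}
    nested_set_ok = True
except TypeError:
    nested_set_ok = False
print(nested_set_ok)                  # False

# ...but frozensets are hashable, so a set of frozensets is fine
family = {frozenset({1, 2}), frozenset()}
print(frozenset({2, 1}) in family)    # True: membership ignores element order
```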


relations.py — equivalence, transitive closure, and topological sort

Built in Chapters 12–13. A relation is a set of ordered pairs (from Chapter 8), so checking its properties is just checking membership. is_equivalence verifies reflexivity, symmetry, and transitivity directly. closure_transitive repeatedly adds the pairs forced by transitivity until nothing new appears, a fixed-point computation. topo_sort (Chapter 13) implements Kahn's algorithm: the "repeatedly remove a minimal element" proof made efficient by tracking each node's in-degree. The same idea drives dependency ordering in build systems, package managers, and task schedulers.

"""dmtoolkit/relations.py -- properties, closure, and ordering (Chapters 12-13).

A relation is a set of (a, b) pairs over a domain. Properties are membership
checks; topo_sort orders a DAG via Kahn's algorithm.
"""


def is_equivalence(rel, dom):
    """True iff rel is reflexive, symmetric, and transitive on domain dom."""
    reflexive = all((a, a) in rel for a in dom)
    symmetric = all((b, a) in rel for (a, b) in rel)
    transitive = all((a, c) in rel
                     for (a, b) in rel for (x, c) in rel if b == x)
    return reflexive and symmetric and transitive


def closure_transitive(rel):
    """Smallest transitive relation containing rel (reachability via paths)."""
    S = set(rel)
    while True:
        new = {(a, c) for (a, b) in S for (x, c) in S if b == x}
        if new <= S:                 # a full pass added nothing -> stable
            return S
        S |= new


def topo_sort(dag):
    """Topologically sort a DAG given as {node: [successors]}. Every edge
    u -> v puts u before v. Raises ValueError if the graph has a cycle."""
    indeg = {u: 0 for u in dag}
    for u in dag:
        for v in dag[u]:
            indeg[v] = indeg.get(v, 0) + 1     # v may appear only as a successor
    ready = sorted(u for u in indeg if indeg[u] == 0)   # minimal elements
    order = []
    while ready:
        u = ready.pop(0)                       # take a minimal element
        order.append(u)
        for v in dag.get(u, []):               # "delete" u from the graph
            indeg[v] -= 1
            if indeg[v] == 0:
                ready.append(v)
        ready.sort()
    if len(order) != len(indeg):
        raise ValueError("graph has a cycle; no topological order exists")
    return order


# Expected behavior (hand-derived):
#   cong3 = {(a, b) for a in range(6) for b in range(6) if (a - b) % 3 == 0}
#   is_equivalence(cong3, set(range(6)))                  -> True
#   is_equivalence({(0, 1)}, {0, 1})                      -> False
#   sorted(closure_transitive({(1, 2), (2, 3), (3, 4)}))
#       -> [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
#   topo_sort({'lexer.o': ['parser.o', 'main'], 'parser.o': ['main'],
#              'utils.o': ['main'], 'main': []})
#       -> ['lexer.o', 'parser.o', 'utils.o', 'main']

🔗 Connection: topo_sort here works directly on a DAG. In Chapter 28 you re-derive a topological order from depth-first search instead — same result, different machinery. The two views are worth holding side by side: this one says "process sources first," the DFS one says "a node finishes after all its descendants."
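For comparison, here is a minimal sketch of the DFS view, assuming the same {node: [successors]} input as topo_sort. The name topo_sort_dfs is hypothetical and this is not the Chapter 28 code. It records each node when it finishes (after all its descendants) and reverses that list; meeting a node that is still on the current path (GRAY) means a back edge, hence a cycle.

```python
def topo_sort_dfs(dag):
    """Topological order of a DAG {node: [successors]} by reverse DFS postorder.
    Raises ValueError if a cycle is found."""
    WHITE, GRAY, BLACK = 0, 1, 2      # unvisited, on the current path, finished
    color = {}
    finished = []                     # nodes in the order they finish

    def visit(u):
        color[u] = GRAY
        for v in dag.get(u, []):
            state = color.get(v, WHITE)
            if state == GRAY:         # back edge: v is an ancestor of u
                raise ValueError("graph has a cycle; no topological order exists")
            if state == WHITE:
                visit(v)
        color[u] = BLACK
        finished.append(u)            # u finishes after all its descendants

    for u in sorted(dag):             # sorted only to make the output repeatable
        if color.get(u, WHITE) == WHITE:
            visit(u)
    return finished[::-1]             # latest finisher first


build = {'lexer.o': ['parser.o', 'main'], 'parser.o': ['main'],
         'utils.o': ['main'], 'main': []}
print(topo_sort_dfs(build))   # ['utils.o', 'lexer.o', 'parser.o', 'main']
```

The order differs from Kahn's ['lexer.o', 'parser.o', 'utils.o', 'main'], yet both are valid: a DAG usually has many topological orders, and each algorithm's tie-breaking picks one. The recursion depth grows with the longest path, so very deep graphs would need an explicit stack.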


combinatorics.py — counting and generating

Built in Chapters 15–17. This module both counts arrangements and generates them, and a recurring exercise is checking that the two agree: the number of arrangements generated must equal the number counted. Chapter 15 contributed the basic counting rules (product_rule, sum_rule, inclusion_exclusion_2) and a first version of perm_count; Chapter 16 supplied the canonical perm_count and comb_count (both built from falling products) plus the generators permutations and combinations; Chapter 17 added stars_and_bars (for distributing identical objects) and derangement_count (permutations with no fixed point). Chapter 15's perm_count is mathematically identical to Chapter 16's, so only the Chapter 16 falling-product version is kept here.
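In that spirit, a standalone sketch that checks derangement_count (copied from the listing) against brute force. It generates every permutation of range(n) with the standard library's itertools.permutations, standing in here for the module's own generator, and keeps those with no fixed point. The helper name derangements_by_brute_force is hypothetical.

```python
from itertools import permutations


def derangement_count(n):
    """D_n via D_n = (n-1)(D_{n-1} + D_{n-2}), with D_0 = 1, D_1 = 0."""
    if n == 0:
        return 1
    a, b = 1, 0                       # D_0, D_1
    for k in range(2, n + 1):
        a, b = b, (k - 1) * (a + b)
    return b


def derangements_by_brute_force(n):
    """Generate every permutation of 0..n-1; count those with p[i] != i for all i."""
    return sum(1 for p in permutations(range(n))
               if all(p[i] != i for i in range(n)))


for n in range(8):
    assert derangements_by_brute_force(n) == derangement_count(n)   # counted == generated
print([derangement_count(n) for n in range(8)])   # [1, 0, 1, 2, 9, 44, 265, 1854]
```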

"""dmtoolkit/combinatorics.py -- counting and generating (Chapters 15-17).

Counters return how many; generators yield the arrangements themselves.
A standing sanity check: len(list(generator)) == counter.
"""
from math import comb


def product_rule(*step_counts):
    """Independent steps with step_counts[i] options each: their product."""
    total = 1
    for n in step_counts:
        total *= n
    return total


def sum_rule(*case_counts):
    """Mutually EXCLUSIVE cases (caller guarantees disjointness): their sum."""
    return sum(case_counts)


def inclusion_exclusion_2(size_a, size_b, size_a_and_b):
    """|A union B| = |A| + |B| - |A intersect B|."""
    return size_a + size_b - size_a_and_b


def perm_count(n, k):
    """k-permutations P(n, k) = n!/(n-k)!, via a falling product.
    Returns 0 if k is out of range; perm_count(n, 0) == 1 (empty product)."""
    if k < 0 or k > n:
        return 0
    result = 1
    for i in range(n, n - k, -1):
        result *= i
    return result


def comb_count(n, k):
    """k-combinations C(n, k) = n!/(k!(n-k)!), kept exact and integer."""
    if k < 0 or k > n:
        return 0
    k = min(k, n - k)                        # symmetry: keep k small
    result = 1
    for i in range(k):
        result = result * (n - i) // (i + 1)
    return result


def permutations(it, k=None):
    """Yield k-permutations (ordered, no repetition) of the items in it."""
    pool = list(it)
    n = len(pool)
    k = n if k is None else k
    if k == 0:
        yield ()
        return
    for i in range(n):
        rest = pool[:i] + pool[i + 1:]           # remove the chosen element
        for tail in permutations(rest, k - 1):   # order the remaining k-1
            yield (pool[i],) + tail


def combinations(it, k):
    """Yield k-combinations (unordered, no repetition) of the items in it."""
    pool = list(it)
    n = len(pool)
    if k == 0:
        yield ()
        return
    for i in range(n - k + 1):                   # leave room for k-1 more
        for tail in combinations(pool[i + 1:], k - 1):   # later items only
            yield (pool[i],) + tail


def stars_and_bars(n, k):
    """Ways to put n identical objects into k distinct boxes (boxes may be
    empty); equivalently the count of non-negative solutions of
    x_1 + ... + x_k = n. The stars-and-bars bijection gives C(n+k-1, k-1)."""
    return comb(n + k - 1, k - 1)


def derangement_count(n):
    """Permutations of n objects with NO fixed point, via the recurrence
    D_n = (n-1)(D_{n-1} + D_{n-2}), with D_0 = 1, D_1 = 0."""
    if n == 0:
        return 1
    a, b = 1, 0                         # D_0, D_1
    for k in range(2, n + 1):
        a, b = b, (k - 1) * (a + b)
    return b


# Expected behavior (hand-derived):
#   perm_count(4, 2)                 -> 12      and len(list(permutations([1,2,3,4], 2))) -> 12
#   comb_count(4, 2)                 -> 6       and len(list(combinations([1,2,3,4], 2))) -> 6
#   list(combinations([1, 2, 3], 2)) -> [(1, 2), (1, 3), (2, 3)]
#   stars_and_bars(12, 4)            -> 455     ( = C(15, 3) )
#   derangement_count(5)             -> 44

🧩 Productive Struggle: Before trusting comb_count, convince yourself the running update result = result * (n - i) // (i + 1) never needs a fraction. The invariant is that after the iteration with index i the value equals $\binom{n}{i+1}$, which is always an integer. Because * and // have equal precedence and group left to right, the product $\binom{n}{i}(n-i) = \binom{n}{i+1}(i+1)$ is formed first, so the integer division is exact at every step, not just at the end. That is why we can use // and never lose precision.
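The invariant can also be checked mechanically. This standalone sketch repeats comb_count's loop but asserts at every step that the division leaves no remainder and that the running value equals $\binom{n}{i+1}$ as computed by math.comb. The name comb_count_traced is hypothetical; the assertions only test small cases, while the identity above is what proves the general one.

```python
from math import comb


def comb_count_traced(n, k):
    """comb_count's running update, with the loop invariant asserted each step."""
    if k < 0 or k > n:
        return 0
    k = min(k, n - k)
    result = 1                                # C(n, 0)
    for i in range(k):
        numerator = result * (n - i)          # C(n, i) * (n - i) = C(n, i+1) * (i + 1)
        assert numerator % (i + 1) == 0       # the division is exact at this step
        result = numerator // (i + 1)
        assert result == comb(n, i + 1)       # invariant: result == C(n, i+1)
    return result


for n in range(15):
    for k in range(n + 1):
        assert comb_count_traced(n, k) == comb(n, k)
print(comb_count_traced(52, 5))               # 2598960 five-card hands
```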


recurrences.py — solving recurrences, and the fast Fibonacci

Built in Chapters 18–19. Chapter 18 added solve_linear, which evaluates the $n$th term of any linear recurrence $a_n = c_1 a_{n-1} + \dots + c_k a_{n-k}$ by faithfully unrolling it in $O(nk)$ arithmetic operations: general, but not yet clever. Chapter 19 added fib, which computes the $n$th Fibonacci number with only $O(\log n)$ matrix multiplications by raising the matrix $M = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}$ to the $n$th power via repeated squaring, using exact integer arithmetic so there is no floating-point error. The Fibonacci thread runs through the whole book; this is its computational payoff.

"""dmtoolkit/recurrences.py -- linear recurrences and fast Fibonacci (Ch. 18-19).

solve_linear unrolls any linear recurrence in O(n*k); fib uses O(log n) matrix
exponentiation with exact integers.
"""


def solve_linear(coeffs, inits, n):
    """nth term of a linear recurrence.
    coeffs = [c1,...,ck]  (c1 multiplies the most recent term a_{n-1});
    inits  = [a0,...,a_{k-1}]  (k initial conditions). Returns a_n."""
    k = len(coeffs)
    if n < k:
        return inits[n]                  # the answer is an initial condition
    window = list(inits)                 # last k terms, oldest first
    for _ in range(k, n + 1):
        nxt = sum(coeffs[i] * window[-1 - i] for i in range(k))
        window.append(nxt)
        window.pop(0)                    # slide: keep only the last k
    return window[-1]


def _mat_mult(A, B):
    """Multiply two 2x2 integer matrices given as ((a, b), (c, d))."""
    (a, b), (c, d) = A
    (e, f), (g, h) = B
    return ((a * e + b * g, a * f + b * h),
            (c * e + d * g, c * f + d * h))


def fib(n):
    """nth Fibonacci number (F_0 = 0, F_1 = 1) in O(log n) time via matrix
    exponentiation of M = [[1, 1], [1, 0]]; M^n has F_n in its top-right entry."""
    if n < 0:
        raise ValueError("fib(n) is defined for n >= 0")
    result = ((1, 0), (0, 1))            # identity (M^0)
    base = ((1, 1), (1, 0))              # M
    while n > 0:
        if n & 1:
            result = _mat_mult(result, base)
        base = _mat_mult(base, base)     # M, M^2, M^4, M^8, ...
        n >>= 1
    return result[0][1]


# Expected behavior (hand-derived):
#   solve_linear([1, 1], [0, 1], 10)     -> 55      (Fibonacci F_10)
#   solve_linear([5, -6], [1, 4], 5)     -> 454     (distinct-roots example)
#   solve_linear([2], [1], 8)            -> 256     (a_n = 2 a_{n-1}, a_0 = 1)
#   [fib(k) for k in range(10)]          -> [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
#   fib(50)                              -> 12586269025   (exact, no rounding)

🔄 Check Your Understanding: Why can fib use exact integer matrices while the closed-form Binet formula $F_n = (\varphi^n - \psi^n)/\sqrt 5$ cannot be trusted in floating point for large $n$?

Answer: Binet's formula involves $\sqrt 5$, an irrational number that no floating-point value represents exactly, so rounding error grows with $n$. The matrix method only ever multiplies and adds integers, so every intermediate value is exact. Both run in $O(\log n)$ arithmetic operations; only the matrix method stays correct.
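The floating-point failure is easy to observe. This sketch compares Binet's formula, evaluated with math.sqrt and rounded to the nearest integer, against an exact integer loop (any exact method works; a simple loop stands in for the matrix fib for brevity) and reports the first $n$ where they disagree. The exact crossover depends on floating-point details, so the sketch prints it rather than promising a value. Assuming IEEE 754 double precision (Python's float on common platforms), a failure must appear by $n = 79$: $F_{79}$ is the first Fibonacci number above $2^{53}$, it is odd, and every double in that range is an even integer.

```python
import math


def fib_exact(n):
    """F_n with exact integer arithmetic (a simple loop stands in for fib)."""
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a


def fib_binet(n):
    """F_n from Binet's formula in floating point, rounded to the nearest int."""
    sqrt5 = math.sqrt(5)
    phi, psi = (1 + sqrt5) / 2, (1 - sqrt5) / 2
    return round((phi ** n - psi ** n) / sqrt5)


# First index where floating-point Binet gives the wrong integer
first_bad = next(n for n in range(200) if fib_binet(n) != fib_exact(n))
print(first_bad, fib_binet(first_bad), fib_exact(first_bad))
```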


probability.py — estimate cheaply, compute exactly

Built in Chapters 20–21. This module is theme four in miniature. simulate (Chapter 20) estimates a probability by Monte Carlo simulation: run a random trial many times and report the fraction of successes. The law of large numbers says this estimate converges, but it never proves an exact value. expected_value (Chapter 20) computes an exact expectation over a finite sample space, using Fraction for the uniform case so the answer is exact. Chapter 21 added the Bayesian update (bayes, total_probability, and the convenience wrapper bayes_update), the computation at the heart of spam filtering, medical-test interpretation, and many classifiers.
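A small standalone illustration of the contrast, using the two-dice event "the sum is 7". The exact value comes from counting outcomes in the 36-element sample space; the estimate runs one possible trial function (an illustrative choice, using a seeded random.Random so reruns match) the way simulate does.

```python
import random
from fractions import Fraction
from itertools import product

# Exact: count favorable outcomes in the 36 equally likely rolls
space = list(product(range(1, 7), repeat=2))
favorable = sum(1 for a, b in space if a + b == 7)
exact = Fraction(favorable, len(space))
print(exact)                          # 1/6

# Estimate: the same event by Monte Carlo, as simulate(trial, k) would do it
rng = random.Random(2024)             # seeded so reruns give the same estimate


def trial():
    """One illustrative trial: roll two dice, report whether they sum to 7."""
    return rng.randint(1, 6) + rng.randint(1, 6) == 7


k = 100_000
estimate = sum(1 for _ in range(k) if trial()) / k
print(estimate)                       # near 0.1667, never exactly 1/6
```

The estimate is a multiple of 1/k, and 1/6 is not a multiple of 1/100,000, so no amount of luck makes the simulation return the exact answer; only the counting argument does.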

"""dmtoolkit/probability.py -- estimate and compute (Chapters 20-21).

simulate ESTIMATES a probability by Monte-Carlo; expected_value computes an
EXACT expectation; bayes does the posterior update. Estimate cheaply, prove exactly.
"""
from fractions import Fraction
import random


def simulate(trial_fn, k):
    """Estimate P(event) by running trial_fn() k times. trial_fn() returns True
    when the event occurs. Returns the empirical fraction (an ESTIMATE, not a proof)."""
    return sum(1 for _ in range(k) if trial_fn()) / k


def expected_value(rv, space, weight=None):
    """Exact E[rv] over a finite sample space.
    rv: outcome -> number.  space: iterable of outcomes.
    weight: outcome -> probability (defaults to uniform / equally likely).
    In the uniform case the result is an exact Fraction."""
    space = list(space)
    if weight is None:                            # equally-likely (Laplace) model
        n = len(space)
        return sum(Fraction(rv(s)) / n for s in space)   # Fraction(x) accepts int, Fraction, or float
    return sum(rv(s) * weight(s) for s in space)


def total_probability(pa, pb_given_a, pb_given_not_a):
    """Evidence P(B) = P(B|A)P(A) + P(B|not A)P(not A) (the Bayes denominator)."""
    return pb_given_a * pa + pb_given_not_a * (1 - pa)


def bayes(pa, pb_given_a, pb):
    """Posterior P(A | B) = P(B|A) * P(A) / P(B).
    pa = prior P(A); pb_given_a = likelihood P(B|A); pb = evidence P(B)."""
    if pb == 0:
        raise ValueError("P(B) must be positive to condition on B")
    return pb_given_a * pa / pb


def bayes_update(pa, pb_given_a, pb_given_not_a):
    """Posterior P(A|B), computing the denominator for you (two-hypothesis case)."""
    pb = total_probability(pa, pb_given_a, pb_given_not_a)
    return bayes(pa, pb_given_a, pb)


# Expected behavior (hand-derived):
#   space = list(itertools.product(range(1, 7), repeat=2))   # 36 equally likely rolls
#   expected_value(lambda s: s[0] + s[1], space)  -> Fraction(7, 1)   (exact mean is 7)
#   simulate(lambda: ..., 100_000)                ~ 0.167  (approx; true P(sum==7) = 1/6)
#   round(bayes_update(0.001, 0.99, 0.01), 4)     -> 0.0902
#       (rare disease: 0.1% base rate, 99% sensitivity, 1% false positive)

⚠️ Common Pitfall: The rare-disease number surprises everyone: a test that is "99% accurate" still leaves a positive result more likely wrong than right when the disease is rare. bayes_update makes the reason concrete — the denominator $P(B)$ is dominated by false positives drawn from the huge healthy population. Always compute the evidence term; never reason from the likelihood alone.
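Running the same numbers through exact Fractions, and then through a natural-frequency count of an imagined population, shows where the 0.0902 comes from. The two functions are copied from the listing; the population of 100,000 is an arbitrary illustrative choice.

```python
from fractions import Fraction


def total_probability(pa, pb_given_a, pb_given_not_a):
    return pb_given_a * pa + pb_given_not_a * (1 - pa)


def bayes_update(pa, pb_given_a, pb_given_not_a):
    pb = total_probability(pa, pb_given_a, pb_given_not_a)
    return pb_given_a * pa / pb


# Same inputs as the rare-disease example, but as exact Fractions
posterior = bayes_update(Fraction(1, 1000), Fraction(99, 100), Fraction(1, 100))
print(posterior)                      # 11/122, about 0.0902

# Natural-frequency view: imagine 100,000 people
population = 100_000
sick = population // 1000                 # 100 have the disease
true_pos = sick * 99 // 100               # 99 of them test positive
false_pos = (population - sick) // 100    # 999 healthy people also test positive
print(true_pos, false_pos)                # 99 999
print(Fraction(true_pos, true_pos + false_pos))   # 11/122 again
```

Of the 1,098 positive results only 99 are true positives: the false positives outnumber them roughly ten to one, which is exactly the dominance of the evidence term described above.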


numbertheory.py — the engine room of RSA

Built in Chapters 22–23. This is where the cryptography becomes possible. Chapter 22 contributed gcd and ext_gcd (the extended Euclidean algorithm, which also returns the Bézout coefficients) and sieve (the Sieve of Eratosthenes). Chapter 23 added the three modular routines RSA depends on: mod_inverse, mod_pow (square-and-multiply, the same fast-exponentiation trick as the matrix fib), and crt (the Chinese Remainder Theorem). Chapter 22 and Chapter 23 each defined gcd/ext_gcd; the consolidated version below uses the Chapter 22 forms, which handle negative inputs cleanly. We also include is_probable_prime and random_prime here — Chapters 23 and 25 refer to "the primality test that ships with numbertheory.py," and crypto.py's full rsa_keygen(bits) needs them.

"""dmtoolkit/numbertheory.py -- the engine room of RSA (Chapters 22-23).

gcd / ext_gcd / sieve from Ch. 22; mod_inverse / mod_pow / crt from Ch. 23.
is_probable_prime / random_prime support key generation in crypto.py.
"""
import random


def gcd(a, b):
    """Greatest common divisor via the Euclidean algorithm. gcd(0, 0) = 0."""
    a, b = abs(a), abs(b)
    while b != 0:
        a, b = b, a % b
    return a


def ext_gcd(a, b):
    """Return (g, x, y) with g = gcd(a, b) and a*x + b*y = g (Bezout)."""
    if b == 0:
        return (abs(a), 1 if a >= 0 else -1, 0)   # keep g non-negative
    g, x1, y1 = ext_gcd(b, a % b)
    return (g, y1, x1 - (a // b) * y1)


def sieve(n):
    """All primes p with 2 <= p <= n (Sieve of Eratosthenes)."""
    if n < 2:
        return []
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(n ** 0.5) + 1):
        if is_prime[p]:
            for m in range(p * p, n + 1, p):
                is_prime[m] = False
    return [i for i in range(2, n + 1) if is_prime[i]]


def mod_pow(base, exp, mod):
    """base ** exp mod 'mod' in O(log exp) multiplications (square-and-multiply)."""
    result, base = 1 % mod, base % mod    # 1 % mod is 0 when mod == 1
    while exp > 0:
        if exp & 1:
            result = (result * base) % mod
        base = (base * base) % mod
        exp >>= 1
    return result


def mod_inverse(a, m):
    """a^{-1} mod m: the unique x in [0, m) with a*x congruent to 1,
    or None if gcd(a, m) != 1."""
    g, s, _ = ext_gcd(a % m, m)
    return s % m if g == 1 else None


def crt(residues, moduli):
    """Solve x congruent to residues[i] (mod moduli[i]) for pairwise-coprime
    moduli; return the unique x in [0, N) where N is the product of the moduli."""
    N = 1
    for m in moduli:
        N *= m
    x = 0
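    for a_i, n_i in zip(residues, moduli):
        N_i = N // n_i
        inv = mod_inverse(N_i, n_i)           # None unless gcd(N_i, n_i) == 1
        if inv is None:
            raise ValueError("moduli must be pairwise coprime")
        x += a_i * N_i * inv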
    return x % N


def is_probable_prime(n, k=20):
    """Miller-Rabin probabilistic primality test with k random witnesses.
    Composite numbers are rejected with probability at least 1 - 4^{-k}."""
    if n < 2:
        return False
    for p in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37):
        if n % p == 0:
            return n == p
    d, r = n - 1, 0
    while d % 2 == 0:                     # write n - 1 = d * 2^r with d odd
        d //= 2
        r += 1
    for _ in range(k):
        a = random.randrange(2, n - 1)
        x = mod_pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):           # repeated squaring looking for -1
            x = (x * x) % n
            if x == n - 1:
                break
        else:
            return False                 # a is a witness: n is composite
    return True


def random_prime(bits):
    """A random probable prime with exactly 'bits' bits (bits >= 2): the top
    bit is forced on so the size is right, and the low bit so it is odd."""
    if bits < 2:
        raise ValueError("a prime needs at least 2 bits")
    while True:
        candidate = random.getrandbits(bits) | (1 << (bits - 1)) | 1
        if is_probable_prime(candidate):
            return candidate


# Expected behavior (hand-derived):
#   gcd(252, 198)            -> 18
#   ext_gcd(252, 198)        -> (18, 4, -5)        (since 252*4 + 198*(-5) = 18)
#   sieve(30)                -> [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
#   mod_inverse(7, 26)       -> 15                 (7 * 15 = 105 congruent to 1 mod 26)
#   mod_pow(3, 13, 7)        -> 3
#   crt([2, 3, 2], [3, 5, 7])-> 23                 (Sun Tzu's classic system)
# (is_probable_prime / random_prime use randomness; their output is not hand-traceable.)

🚪 Threshold Concept: Everything that makes public-key cryptography possible is in this one file. random_prime is fast because primes are abundant and is_probable_prime is cheap, yet factoring the product of two large primes is believed to be hard. That gap — easy to multiply, hard to factor — is the entire foundation of RSA. The functions here are easy; the security rests on a problem nobody knows how to make easy.
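Why does is_probable_prime use Miller-Rabin rather than the simpler Fermat check "is $a^{n-1} \equiv 1 \pmod n$"? Carmichael numbers fool Fermat for every base coprime to them, and Miller-Rabin still catches them. This standalone sketch uses the built-in three-argument pow in place of mod_pow; the helper names fermat_says_prime and miller_rabin_says_prime are hypothetical.

```python
from math import gcd


def fermat_says_prime(n, a):
    """Fermat test with base a: 'probably prime' iff a^(n-1) = 1 (mod n)."""
    return pow(a, n - 1, n) == 1


def miller_rabin_says_prime(n, a):
    """One Miller-Rabin round with base a (same logic as is_probable_prime)."""
    d, r = n - 1, 0
    while d % 2 == 0:                 # n - 1 = d * 2^r with d odd
        d //= 2
        r += 1
    x = pow(a, d, n)
    if x == 1 or x == n - 1:
        return True
    for _ in range(r - 1):            # look for -1 among the repeated squares
        x = x * x % n
        if x == n - 1:
            return True
    return False                      # a is a witness: n is composite


n = 561                               # 3 * 11 * 17, the smallest Carmichael number
print(fermat_says_prime(n, 2))        # True  (Fermat is fooled)
print(miller_rabin_says_prime(n, 2))  # False (2^35, 2^70, 2^140, 2^280 are 263, 166, 67, 1)

# Every base coprime to 561 fools Fermat; that is what "Carmichael" means
print(all(fermat_says_prime(n, a) for a in range(2, n) if gcd(a, n) == 1))   # True
```

The squaring chain reaches 1 without passing through $-1 \equiv 560$, which exhibits a nontrivial square root of 1 modulo 561; no prime modulus has one.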


crypto.py — RSA

Built in Chapters 24–25. This module turns the number theory into a working public-key cryptosystem. Chapter 24 introduced key generation from the group structure of $\mathbb{Z}_n^{*}$ (whose order is $\phi(n) = (p-1)(q-1)$), and Chapter 25 added encryption and decryption and proved the round trip via Euler's theorem. The chapter checkpoints used slightly different signatures while each stood alone; here they are reconciled to the frozen toolkit API: a public key is the pair pub = (n, e), a private key is priv = (n, d), and rsa_encrypt(m, pub) / rsa_decrypt(c, priv) each take a single key. The bits-based rsa_keygen(bits) is the canonical entry point; rsa_keygen_from_primes is kept for the hand-traceable textbook example with the classic primes $p = 61$, $q = 53$. The heavy lifting — mod_inverse and random_prime — is imported from numbertheory.py.

"""dmtoolkit/crypto.py -- textbook RSA (Chapters 24-25).

Key pairs: pub = (n, e), priv = (n, d). Encryption is m^e mod n; decryption is
c^d mod n; correctness follows from Euler's theorem. EDUCATIONAL scale only --
real RSA needs padding (e.g. OAEP) and audited libraries.
"""
from dmtoolkit.numbertheory import mod_inverse, mod_pow, random_prime


def rsa_keygen_from_primes(p, q, e=65537):
    """Build an RSA key pair from chosen distinct primes p, q and public
    exponent e. Returns (pub, priv) = ((n, e), (n, d)). Requires
    gcd(e, (p-1)(q-1)) == 1, so that d = e^{-1} mod phi exists."""
    if p == q:
        raise ValueError("p and q must be distinct primes")
    n = p * q
    phi = (p - 1) * (q - 1)              # |Z_n^*| = (p-1)(q-1)
    d = mod_inverse(e, phi)             # exists iff gcd(e, phi) == 1
    if d is None:
        raise ValueError("e must be coprime to phi(n)")
    return (n, e), (n, d)


def rsa_keygen(bits, e=65537):
    """Generate an RSA key pair whose modulus n has about 'bits' bits, by
    drawing two distinct random primes of half that size each."""
    half = bits // 2
    p = random_prime(half)
    q = random_prime(half)
    while q == p:
        q = random_prime(half)
    return rsa_keygen_from_primes(p, q, e)


def rsa_encrypt(m, pub):
    """Encrypt integer message m (with 0 <= m < n) under public key pub = (n, e)."""
    n, e = pub
    return mod_pow(m, e, n)             # m^e mod n


def rsa_decrypt(c, priv):
    """Decrypt ciphertext c under private key priv = (n, d)."""
    n, d = priv
    return mod_pow(c, d, n)             # c^d mod n


# Expected behavior (hand-derived) with the classic textbook primes:
#   pub, priv = rsa_keygen_from_primes(61, 53, e=17)
#       -> pub = (3233, 17),  priv = (3233, 2753)
#          (n = 3233, phi = 3120, and 17 * 2753 = 46801 = 15*3120 + 1, so d = 2753)
#   c = rsa_encrypt(65, pub)            -> 2790      (65^17 mod 3233)
#   rsa_decrypt(c, priv)               -> 65         (the original message)
# (rsa_keygen(bits) draws random primes, so its output is not hand-traceable.)

⚠️ Common Pitfall: This is textbook RSA — perfect for learning, unsafe for production. Real systems add randomized padding (OAEP), enforce $m < n$, and never roll their own primitives. The value of writing it yourself is understanding why it works, not shipping it. When you need cryptography for real, reach for an audited library, not this file.
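Two properties of the textbook scheme are easy to demonstrate with the $p = 61$, $q = 53$ keys from the example, using small local helpers built on the built-in three-argument pow in place of mod_pow. The round trip works for every message below $n$, but the scheme is also deterministic and malleable: multiplying two ciphertexts yields a valid encryption of the product of the plaintexts, and an attacker can do that without the private key. Randomized padding such as OAEP is designed to destroy exactly this kind of structure.

```python
n, e, d = 3233, 17, 2753              # keys from rsa_keygen_from_primes(61, 53, e=17)


def encrypt(m):
    return pow(m, e, n)               # m^e mod n


def decrypt(c):
    return pow(c, d, n)               # c^d mod n


# Round trip for every message 0 <= m < n (Euler's theorem, plus a CRT
# argument for the m that share a factor with n)
assert all(decrypt(encrypt(m)) == m for m in range(n))

# Deterministic: the same message always gives the same ciphertext
print(encrypt(65))                    # 2790

# Malleable: E(m1) * E(m2) = (m1 * m2)^e, so ciphertexts can be combined
m1, m2 = 7, 11
forged = encrypt(m1) * encrypt(m2) % n
print(decrypt(forged))                # 77 == m1 * m2, forged without knowing d
```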


graphs.py — the algorithmic heart of Part V

Built in Chapters 27–34. This is the largest module, and it is built around one shared Graph class (Chapter 27) that every later algorithm traverses. Because each chapter's checkpoint had to stand alone, the chapters each shipped a minimal stand-in Graph; here they are unified into a single class that supports both weighted and unweighted, directed and undirected graphs, so all the algorithms compose on the same structure. The build order: Chapter 27 the Graph class and the handshaking-lemma helpers; Chapter 28 bfs and dfs; Chapter 29 dijkstra (shortest paths over non-negative weights); Chapter 30 Euler-circuit detection and construction (Hierholzer); Chapter 31 is_tree and tree_height; Chapter 32 mst_kruskal (with a UnionFind helper); Chapter 33 graph coloring; Chapter 34 max_flow (Edmonds–Karp). This is the social-network thread's home: BFS is "degrees of separation," Dijkstra is shortest paths between people, coloring finds conflict-free groupings, and matching pairs people up.
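The handshaking lemma behind Graph.num_edges is easy to see on a toy example. This standalone sketch uses a hypothetical four-person friendship network stored the way Graph stores an undirected graph, with every edge recorded in both endpoints' lists, so the degrees sum to twice the number of edges. It also checks the lemma's corollary that the number of odd-degree vertices is even.

```python
from collections import defaultdict

# Hypothetical undirected friendship network; each pair is one edge
friendships = [("ana", "ben"), ("ana", "cy"), ("ben", "cy"), ("cy", "dee")]

adj = defaultdict(list)               # vertex -> list of neighbors
for u, v in friendships:
    adj[u].append(v)                  # undirected: record the edge both ways,
    adj[v].append(u)                  # as Graph.add_edge does

degree_sum = sum(len(nbrs) for nbrs in adj.values())
print(degree_sum)                             # 8
print(degree_sum // 2 == len(friendships))    # True: |E| = (sum of degrees) / 2

odd = sorted(v for v in adj if len(adj[v]) % 2 == 1)
print(odd)                            # ['cy', 'dee']: an even number of odd-degree vertices
```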

"""dmtoolkit/graphs.py -- graphs and their algorithms (Chapters 27-34).

One Graph class underlies everything: BFS/DFS (Ch. 28), Dijkstra (Ch. 29),
Euler circuits (Ch. 30), tree tests (Ch. 31), Kruskal MST (Ch. 32),
coloring (Ch. 33), and max-flow (Ch. 34).
"""
from collections import deque
import heapq
from itertools import product

INF = float("inf")


class Graph:
    """A graph on an adjacency dict mapping each vertex to a list of
    (neighbor, weight) pairs. Set directed=True for a digraph; weights
    default to 1 so unweighted algorithms can ignore them."""

    def __init__(self, directed=False):
        self.directed = directed
        self.adj = {}                          # vertex -> list of (neighbor, weight)

    def add_vertex(self, v):
        self.adj.setdefault(v, [])             # no-op if v already present

    def add_edge(self, u, v, weight=1):
        self.add_vertex(u)
        self.add_vertex(v)
        self.adj[u].append((v, weight))
        if not self.directed:                  # undirected: record both ways
            self.adj[v].append((u, weight))

    def vertices(self):
        return set(self.adj)

    def neighbors(self, v):
        return [w for (w, _) in self.adj[v]]   # neighbor vertices only

    def weighted_neighbors(self, v):
        return self.adj[v]                     # (neighbor, weight) pairs

    def degree(self, v):
        return len(self.adj[v])                # deg(v) = |N(v)|; out-degree if directed

    def edges(self):
        """List of (u, v, weight). For undirected graphs, each edge once."""
        out, seen = [], set()
        for u in self.adj:
            for v, w in self.adj[u]:
                key = (u, v) if self.directed else frozenset((u, v))
                if key not in seen:
                    seen.add(key)
                    out.append((u, v, w))
        return out

    def num_edges(self):
        """Handshaking lemma: |E| = (sum of degrees) / 2 (undirected)."""
        total = sum(len(nbrs) for nbrs in self.adj.values())
        return total if self.directed else total // 2


class UnionFind:
    """Disjoint-set forest with path halving (used by Kruskal's algorithm)."""

    def __init__(self, items):
        self.parent = {x: x for x in items}

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]   # path halving
            x = self.parent[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.parent[rx] = ry
            return True                        # a real merge happened
        return False                           # already joined (would make a cycle)


def bfs(g, s):
    """Breadth-first search from s. Returns (dist, parent) over reachable
    vertices; dist[v] = number of edges on a shortest s -> v path."""
    dist, parent, queue = {s: 0}, {s: None}, deque([s])
    while queue:
        v = queue.popleft()
        for w in g.neighbors(v):
            if w not in dist:
                dist[w] = dist[v] + 1
                parent[w] = v
                queue.append(w)
    return dist, parent


def dfs(g, s):
    """Depth-first search from s. Returns reachable vertices in pre-order."""
    visited, order = set(), []

    def visit(v):
        visited.add(v)
        order.append(v)
        for w in g.neighbors(v):
            if w not in visited:
                visit(w)
    visit(s)
    return order


def dijkstra(g, s):
    """Single-source shortest paths from s over NON-NEGATIVE weights.
    Returns (dist, prev). Raises ValueError if it meets a negative edge weight."""
    dist = {v: INF for v in g.vertices()}
    prev = {v: None for v in g.vertices()}
    dist[s] = 0
    pq = [(0, s)]
    while pq:
        d, u = heapq.heappop(pq)
        if d > dist[u]:
            continue                           # a stale queue entry
        for v, w in g.weighted_neighbors(u):
            if w < 0:
                raise ValueError("Dijkstra needs non-negative weights")
            if d + w < dist[v]:
                dist[v], prev[v] = d + w, u
                heapq.heappush(pq, (dist[v], v))
    return dist, prev


def shortest_path(g, s, t):
    """Reconstruct one shortest s -> t path as a vertex list ([] if none)."""
    dist, prev = dijkstra(g, s)
    if dist[t] == INF:
        return []
    route = []
    while t is not None:
        route.append(t)
        t = prev[t]
    return route[::-1]


def has_euler_circuit(g):
    """True iff connected g has an Euler circuit: every vertex has even degree."""
    return all(g.degree(v) % 2 == 0 for v in g.vertices())


def euler_circuit(g, start):
    """Hierholzer's algorithm: an Euler circuit as a vertex list, or None.
    Assumes g is undirected and connected with all even degrees; consumes a copy of the edges."""
    if not has_euler_circuit(g):
        return None
    adj = {v: list(g.neighbors(v)) for v in g.vertices()}
    stack, circuit = [start], []
    while stack:
        v = stack[-1]
        if adj[v]:                             # an unused edge out of v?
            w = adj[v].pop()                   # take it
            adj[w].remove(v)                   # remove the undirected twin
            stack.append(w)
        else:
            circuit.append(stack.pop())        # dead end: backtrack and record
    return circuit[::-1]


def is_tree(g):
    """True iff g is a tree: connected with |V| - 1 edges."""
    n = len(g.vertices())
    if n == 0:
        return False
    if g.num_edges() != n - 1:
        return False
    start = next(iter(g.vertices()))
    return len(bfs(g, start)[0]) == n          # reaches every vertex?


def tree_height(g, root):
    """Height of tree g rooted at root: the longest root-to-leaf path length."""
    def h(node, parent):
        kids = [w for w in g.neighbors(node) if w != parent]
        return 0 if not kids else 1 + max(h(c, node) for c in kids)
    return h(root, None)


def mst_kruskal(g):
    """Minimum spanning tree of a connected weighted graph g via Kruskal's
    algorithm. Returns (mst_edges, total_weight); mst_edges is (u, v, weight)."""
    uf = UnionFind(g.vertices())
    mst, total = [], 0
    for u, v, w in sorted(g.edges(), key=lambda e: e[2]):   # ascending weight
        if uf.union(u, v):                     # add only if it joins components
            mst.append((u, v, w))
            total += w
    return mst, total


def greedy_coloring(g, order=None):
    """Proper coloring via the greedy rule: give each vertex the smallest color
    none of its colored neighbors uses. Returns {vertex: color}. May use more than chi(G) colors."""
    order = list(g.vertices()) if order is None else order
    color = {}
    for v in order:
        used = {color[u] for u in g.neighbors(v) if u in color}
        c = 0
        while c in used:
            c += 1
        color[v] = c
    return color


def chromatic_number(g):
    """EXACT chromatic number chi(G) by brute force. Exponential -- small
    graphs only; included to check greedy_coloring on examples."""
    verts = list(g.vertices())
    for k in range(1, len(verts) + 1):
        for assign in product(range(k), repeat=len(verts)):
            c = dict(zip(verts, assign))
            if all(c[u] != c[v] for u in verts for v in g.neighbors(u)):
                return k
    return len(verts)


def max_flow(g, s, t):
    """Maximum flow value from s to t (Edmonds-Karp: BFS augmenting paths).
    Runs in O(V * E^2); returns an integer when all capacities are integers."""
    res = {u: {} for u in g.vertices()}        # residual capacities
    for u in g.vertices():
        for v, c in g.weighted_neighbors(u):
            res[u][v] = res[u].get(v, 0) + c
            res.setdefault(v, {}).setdefault(u, 0)   # reverse edge starts at 0
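    if s == t:                                 # otherwise the loop would add INF forever
        raise ValueError("source and sink must be different vertices")
    total = 0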
    while True:
        parent, q = {s: None}, deque([s])      # BFS for a shortest augmenting path
        while q:
            u = q.popleft()
            for v, c in res[u].items():
                if c > 0 and v not in parent:
                    parent[v] = u
                    q.append(v)
        if t not in parent:                    # no augmenting path: done
            return total
        v, bottleneck = t, INF                 # find the bottleneck capacity
        while parent[v] is not None:
            bottleneck = min(bottleneck, res[parent[v]][v])
            v = parent[v]
        v = t
        while parent[v] is not None:           # push flow, update residuals
            res[parent[v]][v] -= bottleneck
            res[v][parent[v]] += bottleneck
            v = parent[v]
        total += bottleneck


# Expected behavior (hand-derived):
#   Undirected G on Ana-Ben-Cam-Dev-Eve (5 edges):
#     g.degree("Cam") -> 3 ;  g.num_edges() -> 5   (handshaking: sum of degrees 10)
#   bfs/dfs on path-like 0-1-2-3-4: bfs dist to 4 -> ... ; dfs(g, 0) visits every vertex
#   Directed weighted A..E: dijkstra(g, "A") dist
#     -> {'A': 0, 'B': 3, 'C': 1, 'D': 4, 'E': 7}
#        shortest_path(g, "A", "E") -> ['A', 'C', 'B', 'D', 'E']
#   Triangle 1-2-3: has_euler_circuit -> True ;  euler_circuit(g, 1) -> [1, 3, 2, 1]
#   Path 1-2-3-4: is_tree -> True ;  tree_height(g, 1) -> 3
#   Weighted A..E (6 edges): mst_kruskal
#     -> ([('B', 'C', 1), ('A', 'B', 2), ('B', 'D', 4), ('C', 'E', 5)], 12)
#   5-cycle C5: greedy_coloring uses 3 colors ;  chromatic_number -> 3
#   Flow network s->t (capacities 10,3,4,3,10): max_flow(g, "s", "t") -> 10

🔗 Connection: Notice how much reuse one good data structure buys. dijkstra is bfs with a priority queue instead of a plain queue. is_tree leans on both num_edges (the handshaking lemma from Chapter 27) and bfs (connectivity from Chapter 28). mst_kruskal needs only edges and UnionFind. Choosing the right abstraction once, a Graph of (neighbor, weight) pairs, is what lets a dozen functions share one foundation.
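
You can check the claim that dijkstra is bfs with a priority queue directly: give every edge weight 1, and the two searches must return the same distances. The sketch below assumes a plain adjacency dict in place of dmtoolkit's Graph class so that it runs on its own. The small graph and the names bfs_dist and dijkstra_unit are hypothetical.

```python
from collections import deque
import heapq

# Hypothetical standalone graph: an undirected adjacency dict stands in
# for dmtoolkit's Graph class (edges 0-1, 0-2, 1-3, 2-3, 2-4, 3-5, 4-5).
ADJ = {
    0: [1, 2],
    1: [0, 3],
    2: [0, 3, 4],
    3: [1, 2, 5],
    4: [2, 5],
    5: [3, 4],
}


def bfs_dist(adj, s):
    """Unweighted distances: a FIFO queue pops vertices in distance order."""
    dist, queue = {s: 0}, deque([s])
    while queue:
        v = queue.popleft()
        for w in adj[v]:
            if w not in dist:
                dist[w] = dist[v] + 1
                queue.append(w)
    return dist


def dijkstra_unit(adj, s):
    """The same search with a priority queue and every edge weight set to 1."""
    dist, pq = {s: 0}, [(0, s)]
    while pq:
        d, u = heapq.heappop(pq)
        if d > dist[u]:
            continue                          # stale queue entry
        for v in adj[u]:
            if d + 1 < dist.get(v, float("inf")):
                dist[v] = d + 1
                heapq.heappush(pq, (d + 1, v))
    return dist


# The two searches agree from every starting vertex.
agree = all(bfs_dist(ADJ, s) == dijkstra_unit(ADJ, s) for s in ADJ)
```

With unit weights, the heap pops vertices in the same nondecreasing-distance order as the FIFO queue, so the answers coincide. The priority queue only makes a difference once the weights vary.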


coding.py — error-detecting and error-correcting codes

Built in Chapters 26 and 38. Chapter 26 began this module with the Hamming(7,4) single-error-correcting code:

  • hamming_distance: how many positions two codewords differ in.
  • hamming_encode: four data bits to a seven-bit codeword, with parity bits at the power-of-two positions 1, 2, 4.
  • hamming_decode: uses the syndrome to locate and flip a single error.

Chapter 38 added the linear-algebra view, which generalizes to any linear code. syndrome computes the parity-check product $H \mathbf{w}^{\mathsf T}$ over $\mathrm{GF}(2)$, and min_distance finds a linear code's minimum distance as the smallest weight of a nonzero codeword (Theorem 38.4).
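
Theorem 38.4 can be spot-checked on Hamming(7,4) itself. This sketch re-derives the encoder inline so that it runs standalone; the name encode is a hypothetical stand-in for the module's hamming_encode. It enumerates all 16 codewords and computes the minimum distance two ways: from the definition over all pairs, and from the minimum nonzero weight.

```python
from itertools import combinations, product


def encode(d1, d2, d3, d4):
    """Hypothetical stand-in for hamming_encode: parity at positions 1, 2, 4."""
    p1 = d1 ^ d2 ^ d4                     # covers positions 1, 3, 5, 7
    p2 = d1 ^ d3 ^ d4                     # covers positions 2, 3, 6, 7
    p4 = d2 ^ d3 ^ d4                     # covers positions 4, 5, 6, 7
    return (p1, p2, d1, p4, d2, d3, d4)


codewords = [encode(*bits) for bits in product((0, 1), repeat=4)]   # all 16

# Definition: smallest Hamming distance between two distinct codewords (120 pairs).
pairwise = min(sum(x != y for x, y in zip(a, b))
               for a, b in combinations(codewords, 2))

# Theorem 38.4 shortcut for a linear code: smallest nonzero weight (15 words).
min_weight = min(sum(c) for c in codewords if any(c))
```

Both values come out equal. The shortcut is valid only because the code is linear: the bitwise sum of two codewords is again a codeword, and its weight is their distance.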

"""dmtoolkit/coding.py -- error detection and correction (Chapters 26, 38).

Hamming(7,4) corrects any single-bit error; the syndrome/min_distance pair
gives the general linear-algebra view over GF(2).
"""

HAMMING74_H = [[0, 0, 0, 1, 1, 1, 1],          # rows = checks s4, s2, s1
               [0, 1, 1, 0, 0, 1, 1],          # column j (1-indexed) is j in binary
               [1, 0, 1, 0, 1, 0, 1]]


def hamming_distance(a, b):
    """Number of positions where equal-length bit lists a and b differ."""
    return sum(x != y for x, y in zip(a, b))


def hamming_encode(data):
    """Encode 4 data bits into a 7-bit Hamming codeword (parity at 1, 2, 4)."""
    d1, d2, d3, d4 = data                       # data sit at positions 3,5,6,7
    p1 = d1 ^ d2 ^ d4                           # parity over positions 1,3,5,7
    p2 = d1 ^ d3 ^ d4                           # parity over positions 2,3,6,7
    p4 = d2 ^ d3 ^ d4                           # parity over positions 4,5,6,7
    return [p1, p2, d1, p4, d2, d3, d4]         # positions 1..7


def hamming_decode(code):
    """Correct up to one bit error; return (data_bits, error_position).
    error_position is 0 if no error was detected."""
    c = code[:]
    s1 = c[0] ^ c[2] ^ c[4] ^ c[6]             # parity over positions 1,3,5,7
    s2 = c[1] ^ c[2] ^ c[5] ^ c[6]             # parity over positions 2,3,6,7
    s4 = c[3] ^ c[4] ^ c[5] ^ c[6]             # parity over positions 4,5,6,7
    syndrome_pos = s4 * 4 + s2 * 2 + s1        # binary s4 s2 s1 = bad position
    if syndrome_pos:
        c[syndrome_pos - 1] ^= 1               # flip the indicated bit
    return [c[2], c[4], c[5], c[6]], syndrome_pos


def syndrome(H, word):
    """Syndrome H * word^T over GF(2), as a list of bits (top row first)."""
    return [sum(H[i][j] * word[j] for j in range(len(word))) % 2
            for i in range(len(H))]


def min_distance(codewords):
    """Minimum distance of a LINEAR code = the minimum weight of a nonzero
    codeword (Theorem 38.4). codewords is a list of equal-length bit lists."""
    weights = [sum(c) for c in codewords if any(c)]   # skip the all-zero word
    return min(weights)


# Expected behavior (hand-derived):
#   hamming_encode([1, 0, 1, 1])        -> [0, 1, 1, 0, 0, 1, 1]
#       (d1=1,d2=0,d3=1,d4=1 -> p1=0, p2=1, p4=0)
#   flip position 5 -> [0, 1, 1, 0, 1, 1, 1]
#   hamming_decode([0, 1, 1, 0, 1, 1, 1]) -> ([1, 0, 1, 1], 5)   (error found and fixed)
#   hamming_distance([0,1,1,0,0,1,1], [0,1,1,0,1,1,1]) -> 1
#   syndrome(HAMMING74_H, [0, 1, 1, 0, 1, 1, 1]) -> [1, 0, 1]    (= position 5)
#   min_distance([[0, 0, 0], [1, 1, 1]]) -> 3                     (repetition code)

💡 Intuition: The Hamming code's magic is its addressing trick. The three parity checks produce a three-bit number, the syndrome. Read in binary, that number is the position of the flipped bit, or zero if nothing flipped. No search, no comparison: the math points straight at the error. That is why the parity bits sit at positions 1, 2, and 4. Each power of two has a single 1 in binary, so each parity bit falls under exactly one check ($p_1$ under $s_1$, $p_2$ under $s_2$, $p_4$ under $s_4$). As a result, each parity bit can be set independently of the other two.
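
The addressing trick is easy to confirm exhaustively. This sketch builds a parity-check matrix with the same layout as HAMMING74_H (column j is j in binary). It flips each of the seven positions of a valid codeword in turn and reads the syndrome back as a number. The helper names are hypothetical, and syndrome is reimplemented so the block runs on its own. Because the syndrome is linear and a codeword's syndrome is zero, flipping position j leaves exactly column j of H, which is j in binary.

```python
# Parity-check matrix: column j (1-indexed) is j in binary, top row = s4,
# then s2, then s1, matching the layout of HAMMING74_H.
H = [[(j >> k) & 1 for j in range(1, 8)] for k in (2, 1, 0)]


def syndrome(H, word):
    """H * word^T over GF(2), top row first."""
    return [sum(h * w for h, w in zip(row, word)) % 2 for row in H]


def syndrome_to_position(s):
    """Hypothetical helper: read the syndrome bits (s4, s2, s1) as a number."""
    s4, s2, s1 = s
    return 4 * s4 + 2 * s2 + s1


codeword = [0, 1, 1, 0, 0, 1, 1]              # hamming_encode([1, 0, 1, 1])
located = {}
for pos in range(1, 8):                       # flip each position in turn
    received = codeword[:]
    received[pos - 1] ^= 1
    located[pos] = syndrome_to_position(syndrome(H, received))
```

Every flipped position is reported back as itself, and the unflipped codeword gives syndrome 0.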


How the modules fit together

You did not build ten disconnected scripts; you built one library whose pieces lean on each other. The dependencies run mostly downhill:

| Module | Depends on | Anchor thread it advances |
|---|---|---|
| logic.py | (standard library only) | |
| sets.py | (standard library only) | |
| relations.py | sets-as-pairs idea from sets.py | |
| combinatorics.py | (standard library math.comb) | |
| recurrences.py | (standard library only) | Fibonacci (fib) |
| probability.py | (standard library only) | |
| numbertheory.py | (standard library only) | RSA (the engine) |
| crypto.py | numbertheory.py (mod_inverse, mod_pow, random_prime) | RSA (the payoff) |
| graphs.py | (standard library only) | social-network graphs |
| coding.py | (standard library only) | |

Two patterns recur across the whole package and are worth naming, because they are the kind of idea that transfers far beyond this book:

  • Fast exponentiation by repeated squaring appears three times: in fib (matrices), in mod_pow (modular integers), and therefore inside every RSA operation. The idea is the same each time: square the base and consume one bit of the exponent. It needs only $O(\log n)$ multiplications instead of $n - 1$, and it works for anything you can multiply associatively.
  • Breadth-first search is a template, not a one-off. Plain bfs gives unweighted shortest paths; swap its queue for a priority queue and you have dijkstra; run it on a residual network and you have the augmenting-path search inside max_flow. Recognizing one algorithm as a specialization of another is exactly the abstraction skill this book set out to build.
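
The first pattern becomes literal if you write repeated squaring once, parameterized by the multiplication. Both uses then fall out of the same loop. This is a minimal sketch: fast_power, mat_mul, fib_via_power, and mod_pow_via_power are hypothetical names for illustration, not the toolkit's fib and mod_pow.

```python
def fast_power(base, exp, multiply, identity):
    """Repeated squaring for any associative multiply: O(log exp) products.
    Consumes the exponent one bit at a time, least significant bit first."""
    result = identity
    while exp > 0:
        if exp & 1:                           # this bit contributes the current base
            result = multiply(result, base)
        base = multiply(base, base)           # square for the next bit
        exp >>= 1
    return result


def mat_mul(A, B):
    """2x2 integer matrix product."""
    return [[A[0][0] * B[0][0] + A[0][1] * B[1][0],
             A[0][0] * B[0][1] + A[0][1] * B[1][1]],
            [A[1][0] * B[0][0] + A[1][1] * B[1][0],
             A[1][0] * B[0][1] + A[1][1] * B[1][1]]]


def fib_via_power(n):
    """F(n) from [[1, 1], [1, 0]]^n = [[F(n+1), F(n)], [F(n), F(n-1)]]."""
    M = fast_power([[1, 1], [1, 0]], n, mat_mul, [[1, 0], [0, 1]])
    return M[0][1]


def mod_pow_via_power(a, k, n):
    """a^k mod n with the same engine, multiplying modulo n."""
    return fast_power(a % n, k, lambda x, y: x * y % n, 1 % n)
```

The identity element is passed explicitly, which lets one loop serve integers mod $n$ (identity 1) and 2x2 matrices (identity $I$). Associativity is the only algebraic property the loop relies on.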

That is the whole dmtoolkit. Every line of it is mathematics you proved before you coded, and that is the real lesson. The code behaves as the hand-derived outputs predict because the theorems hold, and now you can read, write, and trust both.