Case Study 2: Building an Automatic Differentiation Engine from Scratch

Overview

In this case study, we build a complete automatic differentiation engine from scratch using only NumPy. By the end, our engine will support the common scalar operations and activation functions, compute gradients automatically via reverse-mode differentiation, and be capable of training a simple neural network (vector and matrix support is left as an extension). This exercise demystifies what happens inside frameworks like PyTorch and JAX, transforming "autograd magic" into understandable engineering.

Motivation

Every deep learning framework — PyTorch, TensorFlow, JAX — relies on automatic differentiation at its core. Understanding how autodiff works is not just an academic exercise: it helps you debug gradient issues, design custom operations, understand memory trade-offs, and appreciate why certain architectures are easier to train than others.

We will build our engine in stages:

  1. Stage 1: A Scalar class for scalar-valued reverse-mode autodiff.
  2. Stage 2: Verifying the engine's gradients against numerical (finite-difference) estimates.
  3. Stage 3: Forward-mode autodiff using dual numbers for comparison.
  4. Stage 4: Training a neural network with our engine.

Stage 1: The Core Abstraction

The fundamental abstraction is a Scalar value that tracks its computational history. Each Scalar knows:

  • Its numerical value
  • Which operation created it
  • Which Scalar values were its inputs (stored in the code as the node's _children, since the node is built from them)
  • How to propagate gradients backward through itself

import numpy as np
from typing import Callable, Optional, Union


class Scalar:
    """A scalar value that supports reverse-mode automatic differentiation.

    Each Scalar maintains references to its children in the computational
    graph (the values it was computed from) and a backward function that
    propagates gradients back to them.

    Attributes:
        data: The numerical value.
        grad: The accumulated gradient after backward().
        label: Optional name for debugging/visualization.
    """

    def __init__(
        self,
        data: float,
        children: tuple = (),
        op: str = "",
        label: str = "",
    ) -> None:
        self.data = float(data)
        self.grad = 0.0
        self._backward: Callable[[], None] = lambda: None
        self._children = set(children)
        self._op = op
        self.label = label

    def __repr__(self) -> str:
        label_str = f" ({self.label})" if self.label else ""
        return f"Scalar(data={self.data:.4f}, grad={self.grad:.4f}{label_str})"

    # --- Arithmetic Operations ---

    def __add__(self, other: Union["Scalar", float, int]) -> "Scalar":
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.data + other.data, (self, other), "+")

        def _backward() -> None:
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __radd__(self, other: Union[float, int]) -> "Scalar":
        return self + other

    def __mul__(self, other: Union["Scalar", float, int]) -> "Scalar":
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.data * other.data, (self, other), "*")

        def _backward() -> None:
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _backward
        return out

    def __rmul__(self, other: Union[float, int]) -> "Scalar":
        return self * other

    def __neg__(self) -> "Scalar":
        return self * (-1)

    def __sub__(self, other: Union["Scalar", float, int]) -> "Scalar":
        return self + (-other)

    def __rsub__(self, other: Union[float, int]) -> "Scalar":
        return (-self) + other

    def __truediv__(self, other: Union["Scalar", float, int]) -> "Scalar":
        return self * other ** (-1)

    def __rtruediv__(self, other: Union[float, int]) -> "Scalar":
        return Scalar(other) * self ** (-1)

    def __pow__(self, exponent: Union[float, int]) -> "Scalar":
        assert isinstance(exponent, (int, float)), "Only int/float exponents are supported"
        out = Scalar(self.data ** exponent, (self,), f"**{exponent}")

        def _backward() -> None:
            self.grad += exponent * (self.data ** (exponent - 1)) * out.grad

        out._backward = _backward
        return out

    # --- Activation Functions ---

    def exp(self) -> "Scalar":
        """Compute e^self."""
        out = Scalar(np.exp(self.data), (self,), "exp")

        def _backward() -> None:
            self.grad += out.data * out.grad  # d(e^x)/dx = e^x

        out._backward = _backward
        return out

    def log(self) -> "Scalar":
        """Compute ln(self)."""
        out = Scalar(np.log(self.data), (self,), "log")

        def _backward() -> None:
            self.grad += (1.0 / self.data) * out.grad

        out._backward = _backward
        return out

    def relu(self) -> "Scalar":
        """Compute max(0, self)."""
        out = Scalar(max(0, self.data), (self,), "relu")

        def _backward() -> None:
            self.grad += (self.data > 0) * out.grad

        out._backward = _backward
        return out

    def sigmoid(self) -> "Scalar":
        """Compute 1 / (1 + exp(-self))."""
        s = 1.0 / (1.0 + np.exp(-self.data))
        out = Scalar(s, (self,), "sigmoid")

        def _backward() -> None:
            self.grad += s * (1 - s) * out.grad

        out._backward = _backward
        return out

    def tanh(self) -> "Scalar":
        """Compute tanh(self)."""
        t = np.tanh(self.data)
        out = Scalar(t, (self,), "tanh")

        def _backward() -> None:
            self.grad += (1 - t ** 2) * out.grad

        out._backward = _backward
        return out

    # --- Backward Pass ---

    def backward(self) -> None:
        """Perform reverse-mode autodiff starting from this node.

        Computes gradients for all ancestor nodes in the computational
        graph using topological ordering.
        """
        # Build topological order
        topo: list[Scalar] = []
        visited: set[Scalar] = set()

        def _build_topo(node: Scalar) -> None:
            if node not in visited:
                visited.add(node)
                for child in node._children:
                    _build_topo(child)
                topo.append(node)

        _build_topo(self)

        # Seed gradient and propagate backward
        self.grad = 1.0
        for node in reversed(topo):
            node._backward()

    def zero_grad(self) -> None:
        """Reset gradient to zero."""
        self.grad = 0.0
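
A quick usage sketch (not part of the formal test suite below) shows the engine in action; the expected derivatives are worked out by hand with the chain rule:

# f(a, b) = tanh(a * b + b)
# Chain rule: with u = a*b + b, df/da = (1 - tanh(u)^2) * b and
# df/db = (1 - tanh(u)^2) * (a + 1).
a = Scalar(1.5, label="a")
b = Scalar(-0.5, label="b")
f = (a * b + b).tanh()
f.backward()

u = a.data * b.data + b.data
print(f"df/da = {a.grad:.6f} (expected {(1 - np.tanh(u) ** 2) * b.data:.6f})")
print(f"df/db = {b.grad:.6f} (expected {(1 - np.tanh(u) ** 2) * (a.data + 1):.6f})")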

Stage 2: Verification with Numerical Gradients

Before building on top of our engine, we rigorously verify that it computes correct gradients.

def numerical_gradient_check(
    f: Callable[..., Scalar],
    inputs: list[Scalar],
    h: float = 1e-6,
    tolerance: float = 1e-5,
) -> bool:
    """Verify autodiff gradients against numerical gradients.

    Args:
        f: A function that takes Scalar inputs and returns a Scalar.
        inputs: List of input Scalar values.
        h: Step size for finite differences.
        tolerance: Maximum allowed relative error.

    Returns:
        True if all gradients pass the check.
    """
    # Compute analytical gradients
    for inp in inputs:
        inp.grad = 0.0
    output = f(*inputs)
    output.backward()
    analytical_grads = [inp.grad for inp in inputs]

    # Compute numerical gradients
    numerical_grads = []
    for inp in inputs:
        original = inp.data

        inp.data = original + h
        # Reset grads before re-evaluation
        for x in inputs:
            x.grad = 0.0
        f_plus = f(*inputs).data

        inp.data = original - h
        for x in inputs:
            x.grad = 0.0
        f_minus = f(*inputs).data

        inp.data = original  # restore
        numerical_grads.append((f_plus - f_minus) / (2 * h))

    # Compare
    all_passed = True
    for i, (a, n) in enumerate(zip(analytical_grads, numerical_grads)):
        denom = abs(a) + abs(n) + 1e-15
        rel_error = abs(a - n) / denom
        status = "PASS" if rel_error < tolerance else "FAIL"
        if rel_error >= tolerance:
            all_passed = False
        print(
            f"  Input {i}: analytical={a:.6f}, numerical={n:.6f}, "
            f"rel_error={rel_error:.2e} [{status}]"
        )

    return all_passed


# Test suite
print("Test 1: f(x, y) = x * y + x^2")
x, y = Scalar(2.0), Scalar(3.0)
assert numerical_gradient_check(lambda x, y: x * y + x ** 2, [x, y])

print("\nTest 2: f(x) = sigmoid(x)")
x = Scalar(0.5)
assert numerical_gradient_check(lambda x: x.sigmoid(), [x])

print("\nTest 3: f(x, y) = log(exp(x) + exp(y))  (LogSumExp)")
x, y = Scalar(1.0), Scalar(2.0)
assert numerical_gradient_check(
    lambda x, y: (x.exp() + y.exp()).log(), [x, y]
)

print("\nTest 4: f(x) = relu(x - 1) * relu(1 - x)")
x = Scalar(0.5)
assert numerical_gradient_check(
    lambda x: (x - 1).relu() * (1 - x).relu(), [x]
)

Stage 3: Forward-Mode Autodiff with Dual Numbers

For comparison, we implement forward-mode autodiff using dual numbers.

class DualNumber:
    """A dual number for forward-mode automatic differentiation.

    A dual number a + b*epsilon carries both a value (a) and a derivative (b).
    The key property is that epsilon^2 = 0.

    Attributes:
        value: The real part (function value).
        deriv: The dual part (derivative).
    """

    def __init__(self, value: float, deriv: float = 0.0) -> None:
        self.value = float(value)
        self.deriv = float(deriv)

    def __repr__(self) -> str:
        return f"Dual({self.value:.4f} + {self.deriv:.4f}e)"

    def __add__(self, other: Union["DualNumber", float, int]) -> "DualNumber":
        if not isinstance(other, DualNumber):
            other = DualNumber(other)
        return DualNumber(
            self.value + other.value,
            self.deriv + other.deriv,
        )

    def __radd__(self, other: Union[float, int]) -> "DualNumber":
        return self + other

    def __mul__(self, other: Union["DualNumber", float, int]) -> "DualNumber":
        if not isinstance(other, DualNumber):
            other = DualNumber(other)
        return DualNumber(
            self.value * other.value,
            self.value * other.deriv + self.deriv * other.value,
        )

    def __rmul__(self, other: Union[float, int]) -> "DualNumber":
        return self * other

    def __neg__(self) -> "DualNumber":
        return DualNumber(-self.value, -self.deriv)

    def __sub__(self, other: Union["DualNumber", float, int]) -> "DualNumber":
        return self + (-other)

    def __rsub__(self, other: Union[float, int]) -> "DualNumber":
        return (-self) + other

    def __truediv__(self, other: Union["DualNumber", float, int]) -> "DualNumber":
        if not isinstance(other, DualNumber):
            other = DualNumber(other)
        return DualNumber(
            self.value / other.value,
            (self.deriv * other.value - self.value * other.deriv) / other.value ** 2,
        )

    def __rtruediv__(self, other: Union[float, int]) -> "DualNumber":
        return DualNumber(other) / self

    def __pow__(self, exponent: Union[float, int]) -> "DualNumber":
        return DualNumber(
            self.value ** exponent,
            exponent * self.value ** (exponent - 1) * self.deriv,
        )

    def exp(self) -> "DualNumber":
        """Compute e^self."""
        e = np.exp(self.value)
        return DualNumber(e, e * self.deriv)

    def log(self) -> "DualNumber":
        """Compute ln(self)."""
        return DualNumber(
            np.log(self.value),
            self.deriv / self.value,
        )

    def sin(self) -> "DualNumber":
        """Compute sin(self)."""
        return DualNumber(
            np.sin(self.value),
            np.cos(self.value) * self.deriv,
        )

    def cos(self) -> "DualNumber":
        """Compute cos(self)."""
        return DualNumber(
            np.cos(self.value),
            -np.sin(self.value) * self.deriv,
        )


# Compare forward-mode and reverse-mode
print("=== Comparing Forward-Mode and Reverse-Mode ===")
print("\nFunction: f(x, y) = x^2 * y + exp(x * y) at (1, 2)")

# Forward mode: df/dx (seed x with deriv=1)
x_dual = DualNumber(1.0, 1.0)  # seed for df/dx
y_dual = DualNumber(2.0, 0.0)
result = x_dual ** 2 * y_dual + (x_dual * y_dual).exp()
print(f"Forward mode df/dx = {result.deriv:.6f}")

# Forward mode: df/dy (seed y with deriv=1)
x_dual = DualNumber(1.0, 0.0)
y_dual = DualNumber(2.0, 1.0)  # seed for df/dy
result = x_dual ** 2 * y_dual + (x_dual * y_dual).exp()
print(f"Forward mode df/dy = {result.deriv:.6f}")

# Reverse mode: both gradients in one pass
x_rev = Scalar(1.0, label="x")
y_rev = Scalar(2.0, label="y")
f_rev = x_rev ** 2 * y_rev + (x_rev * y_rev).exp()
f_rev.backward()
print(f"Reverse mode df/dx = {x_rev.grad:.6f}")
print(f"Reverse mode df/dy = {y_rev.grad:.6f}")

Stage 4: Training a Neural Network

The ultimate test: can our autodiff engine train a neural network?

class Neuron:
    """A single neuron with weights, bias, and optional activation.

    Args:
        n_inputs: Number of input features.
        activation: Activation function name ('relu', 'sigmoid', 'linear').
        seed: Random seed.
    """

    def __init__(
        self, n_inputs: int, activation: str = "relu", seed: Optional[int] = None
    ) -> None:
        if seed is not None:
            np.random.seed(seed)
        self.weights = [Scalar(np.random.randn() * 0.5) for _ in range(n_inputs)]
        self.bias = Scalar(0.0)
        self.activation = activation

    def __call__(self, inputs: list[Scalar]) -> Scalar:
        """Compute the neuron output.

        Args:
            inputs: List of Scalar inputs.

        Returns:
            Scalar output after activation.
        """
        # Weighted sum
        total = self.bias
        for w, x in zip(self.weights, inputs):
            total = total + w * x

        # Activation
        if self.activation == "relu":
            return total.relu()
        elif self.activation == "sigmoid":
            return total.sigmoid()
        elif self.activation == "tanh":
            return total.tanh()
        else:
            return total

    def parameters(self) -> list[Scalar]:
        """Return all trainable parameters."""
        return self.weights + [self.bias]


class Layer:
    """A fully connected layer of neurons.

    Args:
        n_inputs: Number of input features.
        n_outputs: Number of neurons in this layer.
        activation: Activation function for all neurons.
    """

    def __init__(
        self, n_inputs: int, n_outputs: int, activation: str = "relu"
    ) -> None:
        self.neurons = [
            Neuron(n_inputs, activation=activation) for _ in range(n_outputs)
        ]

    def __call__(self, inputs: list[Scalar]) -> list[Scalar]:
        """Compute layer output.

        Args:
            inputs: List of Scalar inputs.

        Returns:
            List of Scalar outputs, one per neuron.
        """
        return [neuron(inputs) for neuron in self.neurons]

    def parameters(self) -> list[Scalar]:
        """Return all trainable parameters."""
        params = []
        for neuron in self.neurons:
            params.extend(neuron.parameters())
        return params


class MLP:
    """A multi-layer perceptron built from our autodiff engine.

    Args:
        layer_sizes: List of layer sizes including input dimension.
            E.g., [2, 8, 4, 1] creates a network with input dim 2,
            two hidden layers of size 8 and 4, and output dim 1.
        activations: List of activation functions for each layer
            (excluding input). Defaults to ReLU for hidden, linear
            for output.
    """

    def __init__(
        self,
        layer_sizes: list[int],
        activations: Optional[list[str]] = None,
    ) -> None:
        if activations is None:
            activations = ["relu"] * (len(layer_sizes) - 2) + ["linear"]

        self.layers = []
        for i in range(len(layer_sizes) - 1):
            self.layers.append(
                Layer(layer_sizes[i], layer_sizes[i + 1], activations[i])
            )

    def __call__(self, inputs: list[Union[Scalar, float]]) -> list[Scalar]:
        """Forward pass through the network.

        Args:
            inputs: List of input values (plain floats or Scalars).

        Returns:
            List of Scalar outputs.
        """
        x = [Scalar(v) if not isinstance(v, Scalar) else v for v in inputs]
        for layer in self.layers:
            x = layer(x)
        return x

    def parameters(self) -> list[Scalar]:
        """Return all trainable parameters."""
        params = []
        for layer in self.layers:
            params.extend(layer.parameters())
        return params

    def zero_grad(self) -> None:
        """Reset all gradients to zero."""
        for p in self.parameters():
            p.grad = 0.0


# Training example: learn the XOR function
print("=== Training on XOR ===")

# XOR dataset
X_xor = [[0, 0], [0, 1], [1, 0], [1, 1]]
y_xor = [0, 1, 1, 0]

# Create network: 2 inputs, hidden layer of 8, output of 1
np.random.seed(42)
model = MLP([2, 8, 1], activations=["relu", "sigmoid"])

learning_rate = 0.1
losses = []

for epoch in range(500):
    # Compute loss (binary cross-entropy)
    total_loss = Scalar(0.0)
    for xi, yi in zip(X_xor, y_xor):
        pred = model(xi)[0]
        # Binary cross-entropy: -[y*log(p) + (1-y)*log(1-p)]
        eps = 1e-7
        pred_clipped = pred * (1 - 2 * eps) + eps  # squash into [eps, 1 - eps] to avoid log(0)
        if yi == 1:
            loss_i = -(pred_clipped.log())
        else:
            loss_i = -((1 - pred_clipped).log())
        total_loss = total_loss + loss_i

    total_loss = total_loss * (1.0 / len(X_xor))

    # Backward pass
    model.zero_grad()
    total_loss.backward()

    # SGD update
    for p in model.parameters():
        p.data -= learning_rate * p.grad

    losses.append(total_loss.data)

    if epoch % 100 == 0:
        print(f"Epoch {epoch:3d}: loss = {total_loss.data:.4f}")

# Final predictions
print("\nFinal predictions:")
for xi, yi in zip(X_xor, y_xor):
    pred = model(xi)[0]
    print(f"  Input {xi} -> {pred.data:.4f} (target: {yi})")

Key Insights from Building the Engine

1. The Backward Function Pattern

Every operation creates a new Scalar and defines a closure (_backward) that encodes how to propagate gradients backward. This pattern (create node, define backward) is the fundamental building block of autodiff.
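
To see the pattern in isolation, here is how a new operation could be added outside the class by following the same recipe (scalar_sin is a hypothetical helper for illustration, not part of the engine above):

def scalar_sin(x: Scalar) -> Scalar:
    """Hypothetical sin op: create the output node, then define its backward."""
    out = Scalar(np.sin(x.data), (x,), "sin")

    def _backward() -> None:
        x.grad += np.cos(x.data) * out.grad  # d(sin x)/dx = cos x

    out._backward = _backward
    return out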

2. Gradient Accumulation

Notice the += in the backward functions (e.g., self.grad += out.grad). This is critical: when a value is used in multiple operations, its gradient is the sum of the gradients flowing back from each use. Using = instead of += would silently drop gradient contributions.
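
A tiny illustration: when x feeds into multiple terms, every path contributes to x.grad.

x = Scalar(3.0)
y = x * x + x   # x participates in both the product and the sum
y.backward()
print(x.grad)   # dy/dx = 2*x + 1 = 7.0; with '=' instead of '+=' contributions would be lost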

3. Topological Sorting

The backward pass must visit nodes in reverse topological order (outputs before inputs) so that each node's gradient has been fully accumulated from all of its consumers before its own _backward runs. Our backward() method builds this ordering explicitly.
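
The ordering can be inspected directly by reusing the same traversal outside the class (topo_order is a small illustrative helper, not part of the engine):

def topo_order(root: Scalar) -> list[Scalar]:
    """Return every node reachable from root in topological order (inputs first)."""
    order: list[Scalar] = []
    visited: set[Scalar] = set()

    def visit(node: Scalar) -> None:
        if node not in visited:
            visited.add(node)
            for child in node._children:
                visit(child)
            order.append(node)

    visit(root)
    return order


a, b = Scalar(2.0, label="a"), Scalar(3.0, label="b")
c = (a * b).relu()
print([node._op or node.label for node in topo_order(c)])  # e.g. ['a', 'b', '*', 'relu']
# backward() walks this list in reverse, from 'relu' back to the leaves.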

4. Forward Mode Is Simpler but Less Efficient

The dual number implementation is notably simpler — no graph construction, no backward pass. But for a function with n inputs and a single scalar output, forward mode requires n passes while reverse mode requires just one. For neural networks with millions of parameters and a scalar loss, this difference is decisive.
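
A sketch of that cost difference: with forward mode, obtaining the full gradient of a scalar-valued function of n inputs means re-running the function n times, seeding a different dual part each pass (forward_mode_gradient is an illustrative helper, not part of the engine):

def forward_mode_gradient(f: Callable[..., DualNumber], values: list[float]) -> list[float]:
    """Full gradient of a scalar-valued f via forward mode: one pass per input."""
    grads = []
    for i in range(len(values)):
        # Seed input i with dual part 1, all others with 0.
        duals = [DualNumber(v, 1.0 if j == i else 0.0) for j, v in enumerate(values)]
        grads.append(f(*duals).deriv)
    return grads


# Same function as above: two forward passes for two inputs; reverse mode needed one.
print(forward_mode_gradient(lambda x, y: x ** 2 * y + (x * y).exp(), [1.0, 2.0]))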

5. Memory vs. Computation

Reverse mode stores the entire computational graph in memory. For a deep network with many operations, this can be substantial. This is why techniques like gradient checkpointing (recomputing forward values during the backward pass) are important for training very large models.
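
To make the memory cost concrete, we can count how many Scalar nodes a single forward pass keeps alive (count_nodes and probe_model are illustrative; the exact count depends on the architecture):

def count_nodes(root: Scalar) -> int:
    """Count the Scalar nodes reachable from root, i.e. the graph held for backward()."""
    visited: set[Scalar] = set()

    def visit(node: Scalar) -> None:
        if node not in visited:
            visited.add(node)
            for child in node._children:
                visit(child)

    visit(root)
    return len(visited)


probe_model = MLP([2, 8, 1], activations=["relu", "sigmoid"])
output = probe_model([0.0, 1.0])[0]
print(f"Scalar nodes retained for one forward pass: {count_nodes(output)}")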

Extensions and Challenges

To take this engine further:

  1. Add vector support: Replace scalar operations with NumPy array operations, implementing backward passes for matrix multiply, reshape, and broadcasting.

  2. Implement more operations: Convolution, batch normalization, attention — each requires a custom backward pass.

  3. Add gradient checkpointing: Implement a mechanism to selectively recompute forward values instead of storing them.

  4. Optimize memory: Implement in-place operations and garbage collection for nodes no longer needed.

  5. Compare with PyTorch: Verify that your engine produces the same gradients as PyTorch's autograd on identical computations; a minimal sketch of such a check follows this list.
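
For the last item, a minimal cross-check might look like the following (assuming PyTorch is installed; the test function is arbitrary):

import torch

# Our engine
x, y = Scalar(1.0), Scalar(2.0)
(x * y + x.exp() * y ** 2).backward()

# PyTorch
xt = torch.tensor(1.0, requires_grad=True)
yt = torch.tensor(2.0, requires_grad=True)
(xt * yt + torch.exp(xt) * yt ** 2).backward()

print(f"dx: ours={x.grad:.6f}, torch={xt.grad.item():.6f}")
print(f"dy: ours={y.grad:.6f}, torch={yt.grad.item():.6f}")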

The complete code for this case study is available in code/case-study-code.py.