Visualizing Math Concepts with Matplotlib

Seeing a concept is faster than reading about it. A loss curve tells you in one glance whether your model is training, overfitting, or stuck. A plot of a function shows you where its minimum is. Matplotlib is the standard plotting library for Python, and a handful of its functions cover most of the plotting you'll do in AI work.


Plotting Functions: The Basics

The core workflow is always the same: create an array of x-values with np.linspace(), compute the y-values from it, and call plt.plot(). np.linspace(a, b, n) returns n evenly spaced points from a to b, including both endpoints.
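As a minimal sketch of that workflow, the snippet below plots the same curve from 7 points and from 300 points. The choice of np.sin and of the two point counts is arbitrary, picked only for illustration. plt.plot() joins consecutive points with straight segments, so 7 points give a visibly jagged curve while 300 look smooth.

<pre><code class="language-python">import numpy as np
import matplotlib.pyplot as plt

# 1. x-values: both calls include the endpoints -3 and 3
x_coarse = np.linspace(-3, 3, 7)    # 7 points, spacing (3 - (-3)) / (7 - 1) = 1.0
x_fine = np.linspace(-3, 3, 300)    # 300 points, close enough together to look smooth

# 2. y-values: NumPy applies the function element-wise to the whole array
y_coarse = np.sin(x_coarse)
y_fine = np.sin(x_fine)

# 3. plot: plt.plot draws straight segments between consecutive points
plt.plot(x_fine, y_fine, label='300 points')
plt.plot(x_coarse, y_coarse, 'o--', label='7 points')
plt.legend()
plt.title('Same function, different sampling')
plt.show()</code></pre>

Because both endpoints are counted, the spacing is (b - a) / (n - 1), which is why np.linspace(-3, 3, 7) lands exactly on the integers -3 through 3.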

Plotting a Loss Landscape

<pre><code class="language-python">import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-3, 3, 300)
y_mse = x**2                      # a parabolic loss: squared error as a function of one parameter
y_sigmoid = 1 / (1 + np.exp(-x))  # sigmoid squashes any input into the range (0, 1)

plt.figure(figsize=(10, 4))

# Left panel: the loss curve
plt.subplot(1, 2, 1)
plt.plot(x, y_mse, color='royalblue')
plt.title('MSE Loss: $f(x) = x^2$')
plt.axhline(0, color='gray', lw=0.5)

# Right panel: the activation function
plt.subplot(1, 2, 2)
plt.plot(x, y_sigmoid, color='tomato')
plt.title('Sigmoid Activation')

plt.tight_layout()
plt.show()</code></pre>

Visualizing Gradient Descent

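Plotting the path of gradient descent on a function builds intuition for learning rates. A path that overshoots and bounces from one side of the minimum to the other means the learning rate is too high (and if the bounces keep growing, the run diverges); a path that creeps toward the minimum in tiny steps means it's too low.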

Tracking the Optimization Path

<pre><code class="language-python">import numpy as np
import matplotlib.pyplot as plt

x_val = 3.0      # starting point
lr = 0.3         # learning rate
path = [x_val]
for _ in range(20):
    grad = 2 * x_val     # derivative of x^2
    x_val -= lr * grad   # step downhill, against the gradient
    path.append(x_val)

x_plot = np.linspace(-3.5, 3.5, 200)
plt.plot(x_plot, x_plot**2, 'b-', label='f(x) = x²')
plt.plot(path, [p**2 for p in path], 'ro--', label='GD path')
plt.legend()
plt.title('Gradient Descent on a Parabola')
plt.show()</code></pre>
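For f(x) = x² the effect of the learning rate can be worked out exactly: each update is x ← x - lr·2x = (1 - 2·lr)·x. The iterates shrink steadily toward 0 when 0 &lt; lr &lt; 0.5, flip sign on every step (the zigzag) while still converging when 0.5 &lt; lr &lt; 1, and grow without bound when lr &gt; 1. The sketch below wraps the loop above in a hypothetical helper, gd_path(), and draws one panel per learning rate. The four rates are arbitrary picks, one per regime.

<pre><code class="language-python">import numpy as np
import matplotlib.pyplot as plt


def gd_path(x0, lr, steps):
    """Hypothetical helper: gradient descent on f(x) = x**2, returning every iterate."""
    path = [x0]
    x = x0
    for _ in range(steps):
        x -= lr * 2 * x  # gradient of x**2 is 2x, so x becomes (1 - 2*lr) * x
        path.append(x)
    return np.array(path)


# One learning rate per regime: too low, well chosen, zigzag, divergent
learning_rates = [0.05, 0.3, 0.9, 1.05]

fig, axes = plt.subplots(1, len(learning_rates), figsize=(16, 4))
for ax, lr in zip(axes, learning_rates):
    path = gd_path(3.0, lr, steps=10)
    # Widen the x-range if the path escapes the default window (divergent case)
    lim = max(3.5, 1.1 * np.abs(path).max())
    x_plot = np.linspace(-lim, lim, 200)
    ax.plot(x_plot, x_plot**2, 'b-')
    ax.plot(path, path**2, 'ro--')
    ax.set_title(f'lr = {lr} (factor {1 - 2 * lr:.2f})')
plt.tight_layout()
plt.show()</code></pre>

The panel titles show the per-step factor 1 - 2·lr: a negative factor is what produces the zigzag, and a factor with magnitude above 1 is what makes the path fly away from the minimum.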

Plotting Distributions

Histograms and density plots let you check whether your data follows the distribution you expect, which is critical for catching data-quality issues before training.

Histogram of Samples

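<pre><code class="language-python">import numpy as np
import matplotlib.pyplot as plt

# 5,000 draws from a standard normal distribution (mean 0, std 1)
samples = np.random.normal(loc=0, scale=1, size=5000)

# density=True rescales bar heights so the total bar area is 1
plt.hist(samples, bins=50, density=True, color='steelblue', edgecolor='white', alpha=0.8)
plt.title('Histogram of Normal Samples')
plt.xlabel('Value')
plt.ylabel('Density')
plt.show()</code></pre>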
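A histogram on its own leaves you judging the shape by eye. Drawing the density you expect on top makes mismatches easier to spot. The sketch below assumes the expected distribution is N(0, 1). It writes the normal probability density function by hand as a hypothetical helper, normal_pdf(), to avoid a SciPy dependency. It also swaps np.random.normal for NumPy's seeded Generator (np.random.default_rng) so the figure is reproducible. Because density=True puts the bars on a probability-density scale, the two can be compared directly.

<pre><code class="language-python">import numpy as np
import matplotlib.pyplot as plt


def normal_pdf(x, mu=0.0, sigma=1.0):
    """Hypothetical helper: density of the normal distribution N(mu, sigma**2)."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))


rng = np.random.default_rng(seed=0)  # seeded generator for reproducible samples
samples = rng.normal(loc=0, scale=1, size=5000)

# plt.hist returns the bar heights, the bin edges, and the drawn patches
counts, edges, _ = plt.hist(samples, bins=50, density=True, color='steelblue',
                            edgecolor='white', alpha=0.8, label='samples')

x = np.linspace(-4, 4, 300)
plt.plot(x, normal_pdf(x), color='tomato', lw=2, label='expected N(0, 1) density')
plt.legend()
plt.title('Samples vs. Expected Distribution')
plt.xlabel('Value')
plt.ylabel('Density')
plt.show()

# Numeric checks to go with the picture
print(f'total bar area = {np.sum(counts * np.diff(edges)):.3f}')
print(f'sample mean = {samples.mean():.3f}, sample std = {samples.std():.3f}')</code></pre>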