3D Line Plots
Plotting in three dimensions is straightforward with matplotlib. It is helpful when you want to visualize something like a loss function for machine learning. To start, we are going to build a 2D spiral. The math does not matter here; just focus on what the X and Y data look like.
In [1]:
#Start with a basic 2-D spiral
import math
import numpy as np
import matplotlib.pyplot as plt
#Math to build the spiral
phi = np.linspace(0,math.pi * 10, 1000)
r = phi / math.pi
X = np.cos(phi) * r
Y = np.sin(phi) * r
#Plot the spiral
plt.plot(X, Y)
plt.show()
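To see what the X and Y data look like without plotting, a quick numeric check can help. The sketch below rebuilds the same arrays and confirms that each point's distance from the origin equals r, so the curve winds outward at a steady rate. The name `distance` is introduced here for illustration only.

```python
import math
import numpy as np

#Same construction as the spiral above
phi = np.linspace(0, math.pi * 10, 1000)
r = phi / math.pi
X = np.cos(phi) * r
Y = np.sin(phi) * r

#Distance of each point from the origin (illustrative helper variable)
distance = np.hypot(X, Y)

#Both arrays hold one value per sample of phi
print(X.shape, Y.shape)
#True when the distance from the origin matches r at every point
print(np.allclose(distance, r))
#The spiral starts at the origin and ends at a radius of about 10
print(distance[0], distance[-1])
```

Because phi runs from 0 to 10π, the spiral makes five full turns while the radius grows from 0 to 10.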
Now suppose we have a variable Z that corresponds to height. To show it, we need 3D plotting. Once again, how Z is calculated does not really matter; focus on how we actually plot it. First, create a figure. Then add a subplot with its projection set to '3d', which gives us axes that can draw in three dimensions.
In [2]:
import matplotlib.pyplot as plt
#Only needed on matplotlib versions before 3.2 to register the '3d' projection
from mpl_toolkits.mplot3d import Axes3D
#Create the height values Z as r^2
Z = r ** 2
#Create a figure
fig = plt.figure(figsize=(12,12))
#Add a subplot with 3D projection
ax = fig.add_subplot(111, projection='3d')
#Plot the spiral
ax.plot(X,Y,Z)
plt.show()
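The figure-then-subplot steps above can also be written in one call. The following is a minimal sketch, assuming a reasonably recent matplotlib, that passes the projection through `subplot_kw` and labels the axes so the height direction is easy to identify. The figure size, label text, and viewing angles are arbitrary choices for illustration.

```python
import math
import numpy as np
import matplotlib.pyplot as plt

#Rebuild the spiral and its height so this cell runs on its own
phi = np.linspace(0, math.pi * 10, 1000)
r = phi / math.pi
X = np.cos(phi) * r
Y = np.sin(phi) * r
Z = r ** 2

#Create the figure and the 3D axes in a single call
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={"projection": "3d"})
#Plot the spiral
ax.plot(X, Y, Z)
#Label each axis
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
#Set the camera elevation and rotation in degrees (arbitrary values)
ax.view_init(elev=30, azim=45)
plt.show()
```

Changing the `elev` and `azim` values is a quick way to look at the same curve from a different side.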
If you prefer a scatter plot, call ax.scatter instead of ax.plot.
In [3]:
#You can also do a scatter plot instead
fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X,Y,Z)
plt.show()
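A scatter plot also lets each point carry a color. As one possible extension, the sketch below colors the points by their height Z and adds a colorbar. The colormap, marker size, and colorbar settings are illustrative choices, not part of the original example.

```python
import math
import numpy as np
import matplotlib.pyplot as plt

#Rebuild the spiral and its height so this cell runs on its own
phi = np.linspace(0, math.pi * 10, 1000)
r = phi / math.pi
X = np.cos(phi) * r
Y = np.sin(phi) * r
Z = r ** 2

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111, projection='3d')
#Color each point by its height and use small markers
sc = ax.scatter(X, Y, Z, c=Z, cmap="viridis", s=5)
#Add a colorbar that maps colors back to Z values
fig.colorbar(sc, ax=ax, shrink=0.6, label="Z")
plt.show()
```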