While trying to educate myself on the new hotness of large language models, I ran across this post and its associated chart, and I decided to see if I could recreate that chart programmatically with matplotlib:

A chart on LLMs and the datasets on which they were trained

I got “mostly” there. Here’s what I did.

Step 1: assemble the data

Is this data – the models, which datasets they were trained on, and how big those datasets were – already assembled somewhere on the internet? I just recreated it by hand by reading values off the chart:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


data = [{'model': 'GPT-1', 'dataset':'Wikipedia', 'size_gb':None}, 
        {'model': 'GPT-1', 'dataset':'Books', 'size_gb':5}, 
        {'model': 'GPT-1', 'dataset':'Academic journals', 'size_gb':None}, 
        {'model': 'GPT-1', 'dataset':'Reddit links', 'size_gb':None}, 
        {'model': 'GPT-1', 'dataset':'CC', 'size_gb':None}, 
        {'model': 'GPT-1', 'dataset':'Other', 'size_gb':None}, 
        {'model': 'GPT-2', 'dataset':'Wikipedia', 'size_gb':None}, 
        {'model': 'GPT-2', 'dataset':'Books', 'size_gb':None}, 
        {'model': 'GPT-2', 'dataset':'Academic journals', 'size_gb':None}, 
        {'model': 'GPT-2', 'dataset':'Reddit links', 'size_gb':40}, 
        {'model': 'GPT-2', 'dataset':'CC', 'size_gb':None}, 
        {'model': 'GPT-2', 'dataset':'Other', 'size_gb':None}, 
        {'model': 'GPT-3', 'dataset':'Wikipedia', 'size_gb':11}, 
        {'model': 'GPT-3', 'dataset':'Books', 'size_gb':21}, 
        {'model': 'GPT-3', 'dataset':'Academic journals', 'size_gb':101}, 
        {'model': 'GPT-3', 'dataset':'Reddit links', 'size_gb':50}, 
        {'model': 'GPT-3', 'dataset':'CC', 'size_gb':570}, 
        {'model': 'GPT-3', 'dataset':'Other', 'size_gb':None}, 
        {'model': 'GPT-J/GPT-NeoX-20B', 'dataset':'Wikipedia', 'size_gb':6}, 
        {'model': 'GPT-J/GPT-NeoX-20B', 'dataset':'Books', 'size_gb':118}, 
        {'model': 'GPT-J/GPT-NeoX-20B', 'dataset':'Academic journals', 'size_gb':244}, 
        {'model': 'GPT-J/GPT-NeoX-20B', 'dataset':'Reddit links', 'size_gb':63}, 
        {'model': 'GPT-J/GPT-NeoX-20B', 'dataset':'CC', 'size_gb':227}, 
        {'model': 'GPT-J/GPT-NeoX-20B', 'dataset':'Other', 'size_gb':167}, 
        {'model': 'Megatron-11B', 'dataset':'Wikipedia', 'size_gb':11}, 
        {'model': 'Megatron-11B', 'dataset':'Books', 'size_gb':5}, 
        {'model': 'Megatron-11B', 'dataset':'Academic journals', 'size_gb':None}, 
        {'model': 'Megatron-11B', 'dataset':'Reddit links', 'size_gb':38}, 
        {'model': 'Megatron-11B', 'dataset':'CC', 'size_gb':107}, 
        {'model': 'Megatron-11B', 'dataset':'Other', 'size_gb':None}, 
        {'model': 'MT-NLG', 'dataset':'Wikipedia', 'size_gb':6}, 
        {'model': 'MT-NLG', 'dataset':'Books', 'size_gb':118}, 
        {'model': 'MT-NLG', 'dataset':'Academic journals', 'size_gb':77}, 
        {'model': 'MT-NLG', 'dataset':'Reddit links', 'size_gb':63}, 
        {'model': 'MT-NLG', 'dataset':'CC', 'size_gb':983}, 
        {'model': 'MT-NLG', 'dataset':'Other', 'size_gb':127}, 
        {'model': 'Gopher', 'dataset':'Wikipedia', 'size_gb':12}, 
        {'model': 'Gopher', 'dataset':'Books', 'size_gb':2100}, 
        {'model': 'Gopher', 'dataset':'Academic journals', 'size_gb':164}, 
        {'model': 'Gopher', 'dataset':'Reddit links', 'size_gb':None}, 
        {'model': 'Gopher', 'dataset':'CC', 'size_gb':3450}, 
        {'model': 'Gopher', 'dataset':'Other', 'size_gb':4823}, 
        {'model': 'GPT-4', 'dataset':'Wikipedia', 'size_gb':None}, 
        {'model': 'GPT-4', 'dataset':'Books', 'size_gb':None}, 
        {'model': 'GPT-4', 'dataset':'Academic journals', 'size_gb':None}, 
        {'model': 'GPT-4', 'dataset':'Reddit links', 'size_gb':None}, 
        {'model': 'GPT-4', 'dataset':'CC', 'size_gb':None}, 
        {'model': 'GPT-4', 'dataset':'Other', 'size_gb':None}]

df = pd.DataFrame(data).fillna(0)
df.head()

Note that I converted all NaN values to 0: bad things happen when you try to chart a NaN.
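
A small variation, if you want to remember which sizes the chart actually reported (the code below sticks with the plain fillna(0) dataframe), is to fill only the numeric column and keep a flag:

df = pd.DataFrame(data)
df['reported'] = df['size_gb'].notna()   # True where the chart actually showed a size
df['size_gb'] = df['size_gb'].fillna(0)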

Step 2: code my chart

I borrowed pretty heavily from matplotlib’s 3d bar chart examples. Some points to consider:

  • The datasets are listed along the X axis. In a 3d bar chart, 0 on the X axis sits at the far left and the values increase toward the point where the X axis meets the Y axis. To replicate the chart, then, the “Other” dataset needed to land at around position 0 while the “Wikipedia” dataset needed to land at around position 5. However, I assembled my dataframe starting with “Wikipedia” and working down to “Other”, so when it came time to chart the datasets I had to reverse their order with a handy reverse slice ([::-1]). I had to reverse the colors associated with the datasets as well.
  • I’m always learning new things about Python with these types of exercises. How cool is pandas’ iat accessor? (Not to mention the endless supply of numpy functions I’d never heard of before.)
  • In my dataframe, my X and Y values are categories – datasets and models, respectively. However, bar3d wants numeric coordinates, so I had to plot against placeholder integer positions and then swap the real category names in as tick labels (there’s a small sketch of this right after the list).
  • Setting the width and depth to 0.9 helps create padding between the bars.
  • I had to play with the rotation of my tick labels to get them to more or less line up with the chart. I did this by eye, but I wonder if there’s a better way to align them exactly?
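
To make a couple of those points concrete, here’s a tiny 2d sketch on toy data (the numbers are made up, not taken from the chart) showing the reverse slice, the iat lookup, and the placeholder-integer ticks:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# toy lookup table, standing in for the real dataframe above
toy = pd.DataFrame({'dataset': ['Wikipedia', 'Books', 'Other'], 'size_gb': [11, 21, 5]})

datasets = toy.dataset.tolist()[::-1]        # reverse slice -> ['Other', 'Books', 'Wikipedia']
print(toy['size_gb'].iat[0])                 # .iat grabs a scalar by position -> 11

fig, ax = plt.subplots()
ax.bar(np.arange(len(datasets)), toy.set_index('dataset').loc[datasets, 'size_gb'])
ax.set_xticks(np.arange(len(datasets)))      # plot against placeholder integer positions...
ax.set_xticklabels(datasets, rotation=50)    # ...then swap in the real category names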

So, here’s my code:

fig, ax = plt.subplots(subplot_kw={'projection': '3d'}, figsize=(6, 6))

models = df.model.unique().tolist()
datasets = df.dataset.unique().tolist()[::-1]  # reversed to match the X-axis order of the original chart

# bar heights, iterated model-by-model and then dataset-by-dataset (the same order as the meshgrid below)
top = [df.loc[(df.model==m) & (df.dataset==d), 'size_gb'].iat[0] for m in models for d in datasets]
bottom = np.zeros_like(top)

# bar3d doesn't seem to like categories for x and y values
_x, _y = np.meshgrid(np.arange(len(datasets)), np.arange(len(models)))
x, y = _x.ravel(), _y.ravel()
width = depth = 0.9  # allows a little bit of padding between bars

# one color per dataset (reversed like the dataset order), repeated once per model to match the bar order
ds_colors = ['greenyellow','fuchsia','turquoise','orangered','dodgerblue','goldenrod'][::-1]
colors = ds_colors * len(models)

_ = ax.bar3d(x, y, bottom, width, depth, top, color=colors, shade=True)
# put ticks at the integer placeholder positions, then swap in the real category names
_ = ax.set_yticks(np.arange(len(models)))
_ = ax.set_yticklabels(models, rotation=-20, ha='left')
_ = ax.set_xticks(np.arange(len(datasets)))
_ = ax.set_xticklabels(datasets, rotation=50)

# annotate the size numbers onto the bars
for x1, y1, z1 in zip(x, y, top):
    if z1 > 0:
        _ = ax.text(x1, y1, z1, str(int(z1)), horizontalalignment='left', verticalalignment='bottom')

My attempt at the LLM chart

Ok. There’s still a lot to be done here. The Z axis and its gridlines should go. The size annotations could be better aligned. The scale probably needs rethinking so that the 4823 GB bar doesn’t hide all the other values. The zero-height bars should disappear altogether. And a legend might be nice.
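
Here’s a minimal, untested sketch of how a few of those could go. It reuses the variables from the code above and replaces the bar3d call so the zero-height bars are never drawn, hides the Z ticks, and builds a legend from proxy patches (the legend placement is just a guess):

import matplotlib.patches as mpatches

# rough sketch: reuses x, y, top, colors, width, depth, datasets, ds_colors from above
top = np.asarray(top, dtype=float)
mask = top > 0                                   # keep only the bars with a reported size
_ = ax.bar3d(x[mask], y[mask], np.zeros(mask.sum()), width, depth, top[mask],
             color=np.array(colors)[mask], shade=True)

_ = ax.set_zticks([])                            # hide the Z ticks and labels

# legend: one proxy patch per dataset, paired the same way the bar colors were assigned
handles = [mpatches.Patch(color=c, label=d) for d, c in zip(datasets, ds_colors)]
_ = ax.legend(handles=handles, loc='upper left', bbox_to_anchor=(1.05, 1))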

Anyway, I think I’ve accomplished the crux of what I set out to do, so I’ll leave it there. Hope this helps with some of your 3d charting endeavors!