Skip to content

Commit

Permalink
remove Pandas dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
4imothy committed Nov 9, 2024
1 parent 27d3e09 commit b2888f6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions bench/gather.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from collections import defaultdict
from pandas.compat import sys
import sys
from torch import nn
import time
import torch
import pandas as pd
import matplotlib.pyplot as plt
import ai3
import os
Expand Down Expand Up @@ -188,17 +187,24 @@ def table_save():
models_data[model] = {}
models_data[model].update(cudnn_models[model])

df = pd.DataFrame(models_data).transpose()
df = df.round(4)
columns = list(models_data[next(iter(models_data))].keys()) # Assuming all models have the same keys
row_labels = list(models_data.keys())
table_data = []

for model in row_labels:
row = [round(models_data[model].get(col, 0), 4) for col in columns] # Rounding values
table_data.append(row)

# Plotting
_, ax = plt.subplots(figsize=(10, 6))
ax.axis('off')

table = ax.table(cellText=df.values, colLabels=df.columns, # type: ignore
rowLabels=df.index, cellLoc='center', loc='center') # type: ignore
# Create the table
table = ax.table(cellText=table_data, colLabels=columns,
rowLabels=row_labels, cellLoc='center', loc='center')
table.auto_set_font_size(False)
table.set_fontsize(16)
table.auto_set_column_width(col=list(range(len(df.columns))))
table.auto_set_column_width(col=list(range(len(columns))))
table.scale(1, 2)

ax.annotate(
Expand Down
Binary file added bench/results/model_times.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed bench/results/model_times_table.png
Binary file not shown.

0 comments on commit b2888f6

Please sign in to comment.