-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
40 lines (32 loc) · 1.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pandas as pd
import numpy as np
import streamlit as st
from sklearn.metrics import plot_confusion_matrix, plot_roc_curve, plot_precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
@st.cache(persist=True)
def load_data():
    """Read ``data/mushrooms.csv`` and integer-encode every column.

    Every column is passed through ``LabelEncoder.fit_transform``, which refits
    the encoder on that column alone. The returned frame is therefore fully
    numeric. Streamlit caches the result (``persist=True``).

    Returns:
        pandas.DataFrame: encoded data with the same columns, column order and
        row index as the CSV that was read.
    """
    raw = pd.read_csv('data/mushrooms.csv')
    # A single encoder suffices: fit_transform discards the previous fit each call.
    encode = LabelEncoder().fit_transform
    encoded_columns = {name: encode(raw[name]) for name in raw.columns}
    return pd.DataFrame(encoded_columns, index=raw.index)
@st.cache(persist=True)
def split(df, *, test_size=0.3, random_state=0, target='type'):
    """Split a dataset into train/test features and labels.

    Args:
        df: DataFrame holding the feature columns plus the ``target`` column.
        test_size: fraction of rows held out for the test set (default 0.3).
        random_state: seed forwarded to ``train_test_split`` so the split is
            reproducible across runs (default 0).
        target: name of the label column (default ``'type'``).

    Returns:
        tuple: ``(x_train, x_test, y_train, y_test)`` exactly as returned by
        ``sklearn.model_selection.train_test_split``.
    """
    # Item access instead of ``df.type``: attribute-style column access breaks
    # silently if a column name collides with a DataFrame attribute/method.
    y = df[target]
    x = df.drop(columns=[target])
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, random_state=random_state
    )
    return x_train, x_test, y_train, y_test
def plot_metrics(metrics_list, model, x_test, y_test, class_names):
    """Render the requested evaluation plots for a fitted model in Streamlit.

    Args:
        metrics_list: collection of metric names to show. Recognised values are
            ``'Confusion Matrix'``, ``'ROC Curve'`` and
            ``'Precision-Recall Curve'``; anything else is ignored.
        model: fitted estimator passed to the sklearn plotting helpers.
        x_test: held-out features.
        y_test: held-out labels.
        class_names: labels shown on the confusion-matrix axes.
    """
    # Ordered (title, drawer) pairs. Each title is both the value looked up in
    # metrics_list and the subheader shown above the plot. Lambdas defer
    # plotting until the metric is actually requested.
    # NOTE(review): the plot_* helpers are deprecated since scikit-learn 1.0
    # and removed in 1.2; migrating to *Display.from_estimator needs the
    # module-level import changed too.
    plotters = (
        ("Confusion Matrix",
         lambda: plot_confusion_matrix(model, x_test, y_test,
                                       display_labels=class_names)),
        ("ROC Curve",
         lambda: plot_roc_curve(model, x_test, y_test)),
        ("Precision-Recall Curve",
         lambda: plot_precision_recall_curve(model, x_test, y_test)),
    )
    for title, draw in plotters:
        if title not in metrics_list:
            continue
        st.subheader(title)
        display = draw()
        # Hand Streamlit the figure the Display object drew on, rather than
        # relying on the implicit global pyplot figure (deprecated by
        # Streamlit). clear_figure=True keeps the old clear-after-render behaviour.
        st.pyplot(display.figure_, clear_figure=True)