Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 22 additions & 32 deletions day5/演習3/tests/test_model_inference.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,36 @@
import pytest
import time
import json
import os
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from models.main import DataLoader, ModelTester # パスはリポジトリ構成に合わせて変更


def load_model():
"""モデルをロードする関数"""
# day5/演習1 のモデルを使用
model_path = os.path.join("day5", "演習1", "models", "titanic_model.pkl")
with open(model_path, "rb") as f:
model = pickle.load(f)
return model
def test_model_inference_accuracy():
"""モデルの推論精度をテスト"""
# モデル読み込み
model = ModelTester.load_model("models/titanic_model.pkl")

# テストデータの読み込みと前処理
data = DataLoader.load_titanic_data("data/Titanic.csv")
X, y = DataLoader.preprocess_titanic_data(data)

def load_test_data():
"""テストデータをロードする関数"""
# day5/演習1 のデータを使用
data_path = os.path.join("day5", "演習1", "data", "test.csv")
# 精度評価
y_pred = model.predict(X)
accuracy = (y_pred == y).mean()

# データを読み込む
df = pd.read_csv(data_path)
# 閾値に基づくテスト
assert accuracy >= 0.75, f"Accuracy too low: {accuracy:.4f}"

# 前処理
if "Survived" in df.columns:
X = df.drop(["Survived"], axis=1)
y = df["Survived"]
else:
X = df
# ダミーのラベル
y = pd.Series([0, 1] * (len(df) // 2) + [0] * (len(df) % 2))

# カテゴリカル変数の処理
X = pd.get_dummies(X)
def test_model_inference_time():
"""モデルの推論時間をテスト"""
model = ModelTester.load_model("models/titanic_model.pkl")
data = DataLoader.load_titanic_data("data/Titanic.csv")
X, _ = DataLoader.preprocess_titanic_data(data)

return X, y
start = time.time()
_ = model.predict(X)
elapsed = time.time() - start

assert elapsed < 1.0, f"Inference took too long: {elapsed:.4f} sec"

def test_model_inference_accuracy():
"""モデルの推論精度をテスト"""
model = load_model()
X, y = load_test_data()
Loading