|
| 1 | +#!/usr/bin/env python |
| 2 | +# Created by "Thieu" at 10:33, 25/05/2025 ----------% |
| 3 | +# Email: nguyenthieu2102@gmail.com % |
| 4 | +# Github: https://github.com/thieu1995 % |
| 5 | +# --------------------------------------------------% |
| 6 | + |
| 7 | +import pytest |
| 8 | +import numpy as np |
| 9 | +from sklearn.datasets import make_classification |
| 10 | +from waveletml import MhaWnnClassifier |
| 11 | + |
| 12 | + |
| 13 | +def test_binary_classification_fit_predict_score(): |
| 14 | + X, y = make_classification(n_samples=120, n_features=6, n_classes=2, random_state=42) |
| 15 | + |
| 16 | + model = MhaWnnClassifier( |
| 17 | + size_hidden=5, |
| 18 | + optim="BaseGA", # Use Genetic Algorithm for diversity |
| 19 | + optim_params={"epoch": 20, "pop_size": 20}, |
| 20 | + obj_name="AS", # Accuracy Score |
| 21 | + seed=42, |
| 22 | + verbose=False |
| 23 | + ) |
| 24 | + |
| 25 | + model.fit(X, y) |
| 26 | + y_pred = model.predict(X) |
| 27 | + assert y_pred.shape == (X.shape[0],) |
| 28 | + assert np.all(np.isin(y_pred, [0, 1])) |
| 29 | + |
| 30 | + score = model.score(X, y) |
| 31 | + assert isinstance(score, float) |
| 32 | + assert 0 <= score <= 1 |
| 33 | + |
| 34 | + |
| 35 | +def test_multiclass_classification_fit_predict(): |
| 36 | + X, y = make_classification(n_samples=100, n_features=5, n_classes=3, n_informative=3, n_redundant=0, |
| 37 | + random_state=42) |
| 38 | + |
| 39 | + model = MhaWnnClassifier( |
| 40 | + size_hidden=6, |
| 41 | + optim="OriginalPSO", # Particle Swarm Optimization |
| 42 | + optim_params={"epoch": 20, "pop_size": 20}, |
| 43 | + obj_name="AS", |
| 44 | + seed=1, |
| 45 | + verbose=False |
| 46 | + ) |
| 47 | + |
| 48 | + model.fit(X, y) |
| 49 | + y_pred = model.predict(X) |
| 50 | + assert y_pred.shape == (X.shape[0],) |
| 51 | + assert set(np.unique(y_pred)).issubset(set(np.unique(y))) |
| 52 | + |
| 53 | + |
| 54 | +def test_predict_proba_output_shape(): |
| 55 | + X, y = make_classification(n_samples=80, n_features=4, n_classes=2, random_state=0) |
| 56 | + |
| 57 | + model = MhaWnnClassifier( |
| 58 | + optim="OriginalDE", |
| 59 | + optim_params={"epoch": 20, "pop_size": 20}, |
| 60 | + obj_name="AS", |
| 61 | + verbose=False |
| 62 | + ) |
| 63 | + |
| 64 | + model.fit(X, y) |
| 65 | + probs = model.predict_proba(X) |
| 66 | + assert isinstance(probs, np.ndarray) |
| 67 | + assert probs.shape == (X.shape[0], 1) or probs.shape[1] > 1 |
| 68 | + |
| 69 | + |
| 70 | +def test_invalid_obj_name_raises_error(): |
| 71 | + X, y = make_classification( |
| 72 | + n_samples=60, |
| 73 | + n_features=3, |
| 74 | + n_informative=2, |
| 75 | + n_redundant=0, |
| 76 | + n_repeated=0, |
| 77 | + n_classes=2, |
| 78 | + random_state=0 |
| 79 | + ) |
| 80 | + |
| 81 | + model = MhaWnnClassifier( |
| 82 | + optim="BaseGA", |
| 83 | + optim_params={"epoch": 20, "pop_size": 20}, |
| 84 | + obj_name="INVALID_METRIC", |
| 85 | + verbose=False |
| 86 | + ) |
| 87 | + |
| 88 | + with pytest.raises(ValueError, |
| 89 | + match="obj_name is not supported. Please check the library: permetrics to see the supported objective function."): |
| 90 | + model.fit(X, y) |
| 91 | + |
| 92 | + |
| 93 | +def test_predict_proba_invalid_task(): |
| 94 | + X, y = make_classification( |
| 95 | + n_samples=60, |
| 96 | + n_features=3, |
| 97 | + n_informative=2, |
| 98 | + n_redundant=0, |
| 99 | + n_repeated=0, |
| 100 | + n_classes=2, |
| 101 | + random_state=0 |
| 102 | + ) |
| 103 | + |
| 104 | + model = MhaWnnClassifier( |
| 105 | + optim="BaseGA", |
| 106 | + optim_params={"epoch": 25, "pop_size": 20}, |
| 107 | + obj_name="AS", |
| 108 | + verbose=False |
| 109 | + ) |
| 110 | + |
| 111 | + model.fit(X, y) |
| 112 | + model.task = "regression" # Force task to be incorrect |
| 113 | + with pytest.raises(ValueError, match="predict_proba is only available for classification tasks."): |
| 114 | + model.predict_proba(X) |
0 commit comments