Counterfactuals demonstration and computation for tabular

This commit is contained in:
atla8167 2024-07-01 18:17:14 +03:00
parent 6047fad7ce
commit 94c682362f
21 changed files with 832 additions and 82 deletions

Binary file not shown.

Binary file not shown.

@ -0,0 +1,54 @@
from base import views
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
from django_plotly_dash import DjangoDash
import plotly.express as px
import pandas as pd
from sklearn.manifold import TSNE
import json
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
app = DjangoDash("Tsne", external_stylesheets=external_stylesheets)
excel_file_name_preprocessed = "breast-cancer_preprocessed.csv"
preprocess_df = pd.read_csv(excel_file_name_preprocessed)
preprocess_df.drop(["id"], axis=1, inplace=True)
features = preprocess_df.loc[:, :"compactness_se"]
tsne = TSNE(n_components=2, random_state=0)
projections = tsne.fit_transform(features)
tsne = px.scatter(
projections,
x=0,
y=1,
color=preprocess_df.diagnosis.astype(str),
labels={"color": "diagnosis"},
)
tsne.update_layout(clickmode="event+select")
app.layout = html.Div(
[
dcc.Graph(id="scatter-plot", config={"displayModeBar": False}, figure=tsne),
html.Div(
className="row",
children=[
html.Div(
[
dcc.Input(
id="click-data",
type="hidden",
className="three columns",
)
]
)
],
),
]
)
@app.callback(Output("click-data", "children"), Input("scatter-plot", "clickData"))
def display_click_data(clickData, request):
# dictionary of list of dictionary ???
views.counterfactuals(clickData, request)
return json.dumps(clickData, indent=2)

@ -0,0 +1,16 @@
# Generated by Django 4.2.13 on 2024-06-17 08:51
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("base", "0002_document_delete_upload"),
]
operations = [
migrations.DeleteModel(
name="Document",
),
]

@ -288,12 +288,13 @@ nav input {
height: min-content;
max-height: 40%;
max-width: 35%;
border: 1px solid #bbb;
padding: 5px;
}
.dataframe {
font-size: 9pt;
font-family: Arial;
border-collapse: collapse;
font-size: 0.9em;
}
@ -306,6 +307,13 @@ nav input {
.dataframe td {
padding: 12px 15px;
text-align: left;
border: black;
border-collapse: separate;
}
.dataframe .clickedrow th,
.dataframe .clickedrow td {
background-color: #c6bdf8;
}
.dataframe tbody tr {
@ -319,31 +327,36 @@ nav input {
.dataframe tbody tr:last-of-type {
border-bottom: 2px solid #009879;
}
.dataframe tbody tr:hover {
background-color: #e8e5f9;
}
/* plotly toolbar */
.modebar {
display: none !important;
}
/* button */
/* CSS */
.button-6 {
align-items: center;
background-color: #FFFFFF;
background-color: #ffffff;
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: .25rem;
border-radius: 0.25rem;
box-shadow: rgba(0, 0, 0, 0.02) 0 1px 3px 0;
box-sizing: border-box;
color: rgba(0, 0, 0, 0.85);
cursor: pointer;
display: inline-flex;
font-family: system-ui,-apple-system,system-ui,"Helvetica Neue",Helvetica,Arial,sans-serif;
font-family: system-ui, -apple-system, system-ui, "Helvetica Neue", Helvetica,
Arial, sans-serif;
font-size: 16px;
font-weight: 600;
justify-content: center;
line-height: 1.25;
margin: 0;
min-height: 3rem;
padding: calc(.875rem - 1px) calc(1.5rem - 1px);
padding: calc(0.875rem - 1px) calc(1.5rem - 1px);
position: relative;
text-decoration: none;
transition: all 250ms;
@ -366,7 +379,7 @@ nav input {
}
.button-6:active {
background-color: #F0F0F1;
background-color: #f0f0f1;
border-color: rgba(0, 0, 0, 0.15);
box-shadow: rgba(0, 0, 0, 0.06) 0 2px 4px;
color: rgba(0, 0, 0, 0.65);
@ -379,7 +392,7 @@ nav input {
display: flex;
flex-wrap: wrap;
border-radius: 0.5rem;
background-color: #EEE;
background-color: #eee;
box-sizing: border-box;
box-shadow: 0 0 0px 1px rgba(0, 0, 0, 0.06);
padding: 0.25rem;
@ -403,12 +416,32 @@ nav input {
justify-content: center;
border-radius: 0.5rem;
border: none;
padding: .5rem 0;
padding: 0.5rem 0;
color: rgba(51, 65, 85, 1);
transition: all .15s ease-in-out;
transition: all 0.15s ease-in-out;
}
.radio-inputs .radio input:checked + .name {
background-color: #fff;
font-weight: 600;
}
}
.loader {
width: 48px;
height: 48px;
border: 5px solid #fff;
border-bottom-color: #ff3d00;
border-radius: 50%;
display: inline-block;
box-sizing: border-box;
animation: rotation 1s linear infinite;
}
@keyframes rotation {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}

@ -0,0 +1,108 @@
$(document).ready(function () {
// When click on graph (tsne)
// iterate through the elements and find the one with opacity 1
// the rest have opacity 0.2 when a button is clicked...
// Case 1: They are all opacity 1!
// In that case we cannot check all of them
// it would be inneficient,
// though if we just find one more we can determine
// that all the rest are also opacity: 1 and thus
// there was no click on a specific point
// Case 2: There was actually a click on a point!
// In that case we need to find that point. Poltly
// does not save all the points on one <g> element
// Instead on the particular example it would use 2
// <g> elements with multiple <points> in them.
// We iterate through all the collections and all the points
// until we find the point element with opacity 1.
// That way feels inefficient but in reality rarely will
// the last ever point of the last collection be selected
// and even then the complexity is just O(n). Not great
// but not bad either. It would help to be able to determine
// if a point is in collection i or j but I am not there
// yet.
document.getElementById('tsne').addEventListener('click', function () {
// remove original and counterfactual columns
if (document.getElementById("og_cf_row") && document.getElementById("og_cf_headers")) {
document.getElementById("og_cf_row").style.display = "none";
document.getElementById("og_cf_headers").style.display = "none";
}
var tsne = document.getElementById('tsne');
var points_array = tsne.getElementsByClassName("points")
var counter_break = 0;
var final_position = 0;
var j;
var i;
for (i = 0; i < points_array.length; i++) {
var points = points_array[i]
var children = points.children
if (getComputedStyle(children[0]).opacity == 1 && getComputedStyle(children[1]).opacity == 1) {
break;
}
console.log(children.length)
for (j = 0; j < children.length; j++) {
var st = children[j]
opacity = getComputedStyle(st).opacity
if (opacity == 1) {
console.log("i: ", i)
console.log("J: ", j)
console.log("children.length: ", children.length)
counter_break++;
final_position += j;
console.log("final_position: ", final_position)
break;
}
}
if (counter_break) {
break;
}
final_position += j;
}
if (counter_break == 0) {
$("#cfrow").remove();
document.getElementById("cftable").style.display = "none";
document.getElementById("cfbtn").style.display = "none";
} else if (counter_break == 1) {
var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
$.ajax({
method: 'POST',
url: '',
headers: { 'X-CSRFToken': csrftoken },
data: { 'row': final_position, 'action': "click_graph" },
success: function (ret) {
row = JSON.parse(ret)
row = row["row"]
var tb = document.createElement('table');
tb.innerHTML = row.trim();
tb.setAttribute("id", "cfrow");
tb.setAttribute("class", "dataframe")
if (document.getElementById("cfrow")) {
$("#cfrow").remove();
document.getElementById("cftable").style.display = "none";
}
$("#cftable").append(tb);
document.getElementById("cftable").style.display = "block";
document.getElementById("cfbtn").style.display = "block";
},
error: function (row) {
console.log("it didnt work");
}
});
} else if (counter_break > 1) {
console.log("All opacity 1")
}
});
});

@ -0,0 +1,30 @@
$(document).ready(function () {
//Highlight clicked row
$('.dataframe td').on('click', function () {
// Remove previous highlight class
$(this).closest('.dataframe').find('tr.clickedrow').removeClass('clickedrow');
// add highlight to the parent tr of the clicked td
$(this).closest('tr').addClass("clickedrow");
// fetch row
var $row = this.closest('tr')
var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
$.ajax({
method: 'POST',
url: '',
headers: { 'X-CSRFToken': csrftoken },
data: { 'row': $row.innerHTML.trim().split("\n").map(line => line.trim()).join("\n") },
success: function (id) {
var myPlot = document.getElementById('tsne');
var paths = myPlot.querySelector('.points');
var path = paths.children[id];
path.setAttribute("style", "opacity: 1; stroke-width: 0px; fill: green; fill-opacity: 1; scale: 2.0;");
console.log(path)
},
error: function (id) {
console.log("it didnt work");
}
});
});
});

@ -0,0 +1,27 @@
$(document).ready(function () {
//Highlight clicked row
document.getElementById('reset_div').addEventListener('click', function () {
// on click reset the graph
// if reset button exists, hide it
document.getElementById("reset_div").style.display = "none";
document.getElementById("loader_cf").style.display = "block";
var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
$.ajax({
method: 'POST',
url: '',
headers: { 'X-CSRFToken': csrftoken },
data: {'action': "reset_graph" },
success: function (ret) {
document.getElementById("loader_cf").style.display = "none";
ret = JSON.parse(ret)
fig = ret["fig"]
document.getElementById("tsne").innerHTML = "";
$("#tsne").append(fig)
document.getElementById("reset_div").style.display = "none";
},
error: function (ret) {
}
});
});
});

@ -0,0 +1,62 @@
$(document).ready(function () {
var mySelect = document.getElementById('cfrow_id');
mySelect.onchange = (event) => {
var rowId = event.target.value;
// inputText is the row id of the counter factual
// Send ajax request to backend and inquire for the
// respective counterfactual. When it is acquired
// replace current counterfactual view with that...
var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
document.getElementById("loader_cf").style.display = "block";
// if reset button exists, hide it
if (document.getElementById("reset_div"))
document.getElementById("reset_div").style.display = "none";
$("#show_image").show()
$.ajax({
method: 'POST',
url: '',
headers: { 'X-CSRFToken': csrftoken },
data: { 'row': rowId, 'action': "counterfactual_select" },
success: function (ret) {
// stop loader
document.getElementById("loader_cf").style.display = "none";
// add <input type="reset" value="Reset">
if (!document.getElementById("reset")) {
var reset = document.createElement('input')
reset.setAttribute("id", "reset")
reset.setAttribute("type", "reset")
reset.setAttribute("value", "Reset")
$("#reset_div").append(reset)
}
// if reset button has been created once just display it
document.getElementById("reset_div").style.display = "block";
var tb = document.createElement('table');
ret = JSON.parse(ret)
row = ret["row"]
tb.innerHTML = row.trim();
tb.setAttribute("id", "counterfactual_selected");
tb.setAttribute("class", "dataframe")
if (document.getElementById("counterfactual_selected")) {
$("#counterfactual_selected").remove();
}
$("#counterfactual").append(tb);
document.getElementById("counterfactual_selected").style.display = "block";
fig = ret["fig"]
document.getElementById("tsne").innerHTML = "";
$("#tsne").append(fig)
},
error: function (row) {
console.log("it didnt work");
}
});
}
});

@ -3,7 +3,7 @@ function setScreen() {
window.scrollTo(0, yScreen);
}
function setScroll() {
var yScroll = window.pageYOffset;
var yScroll = window.scrollY;
localStorage.setItem("yPos", yScroll);
}
function clearScreen() {

@ -2,6 +2,26 @@
{% block content %}
{% load static %}
<!-- <div class="container-fluid">
<div class="row">
<div class="col-sm d-flex justify-content-center" >
<div class="dataset">
<div class="form-check">
<input class="form-check-input" type="radio" name="flexRadioDefault" id="flexRadioDefault1">
<label class="form-check-label" for="flexRadioDefault1">
Breast Cancer Dataset
</label>
</div>
<div class="form-check">
<input class="form-check-input" type="radio" name="flexRadioDefault" id="flexRadioDefault2">
<label class="form-check-label" for="flexRadioDefault2">
Other
</label>
</div>
</div>
</div>
</div>
</div> -->
<div class="container-fluid">
<div class="row">
<div class="col-sm d-flex justify-content-center">
@ -68,8 +88,10 @@
</div>
</div>
</div>
<br>
<br>
<div class="container-fluid">
<div class="row">
<section>
@ -112,25 +134,20 @@
<button type="submit" class="button-6" role="button" name="preprocess" form="preprocess-form">Go!</button>
</div>
</div>
<br>
<br>
<div class="row">
{% if pca %}
<div class="col-sm d-flex justify-content-center">
{{ pca|safe }}
</div>
{% endif %}
{% if tsne %}
<div class="col-sm d-flex justify-content-center">
{{ tsne|safe }}
</div>
{% endif %}
{% if pca %}
<div class="col d-flex justify-content-center" id="pca">
{{ pca|safe }}
</div>
{% endif %}
</div>
<br>
<br>
</div>
<div class="container-fluid">
<form action="{% url 'home' %}" method="POST" id="traintest-form">
{% csrf_token %}
@ -147,7 +164,7 @@
<option type="submit" value="logit" selected>Logistic Regression</option>
<option type="submit" value="xgb">XGBoost</option>
<option type="submit" value="dt">Decision Tree</option>
<option type="submit" value="rt">Random Forest</option>
<option type="submit" value="rf">Random Forest</option>
</select>
</label>
</div>
@ -200,6 +217,7 @@
{{ fig2|safe }}
</div>
{% endif %}
{% if clas_report %}
<div class="col-sm d-flex justify-content-center">
<div class="scrollit">
@ -209,4 +227,97 @@
{% endif %}
</div>
</div>
<br>
<br>
<div class="container-fluid">
<div class="row">
<section>
<label style="display:flex;
flex-direction:column;
align-items: center;">
<h2>
<i class="fas fa-cog"></i> Counterfactuals
</h2>
<h5>
Pick a point on the graph
</h5>
</label>
</section>
<div class="col-sm d-flex justify-content-center">
</div>
</div>
{% if tsne %}
<div class="row">
<div class="col-md justify-content-center" id="tsne">
{{ tsne|safe }}
</div>
</div>
{% endif %}
<div class="row">
<form action="{% url 'home' %}" method="POST" id="cf-form">
{% csrf_token %}
<div class="col-md d-flex justify-content-center">
<div class="scrollit" class="cftable" name="cfrow" id="cftable" style="border: 1px solid #bbb; padding: 5px; display: none; max-width: 65%;" >
</div>
</div>
</form>
</div>
<div class="row">
<div class="col-sm d-flex justify-content-center">
<button id="cfbtn" type="submit" class="button-6" role="button" name="cf" form="cf-form" style="display:none;">Run Counterfactuals!</button>
</div>
</div>
{% if cfrow_og and cfrow_cf %}
<div class="row" id="og_cf_headers">
<div class="col d-flex justify-content-center">
<h2>
Original data
</h2>
</div>
<span style="display: none; width: 48px; height: 48px;" id="reset_div">
</span>
<span class="loader" id="loader_cf" style="display: none;">
</span>
<div class="col d-flex justify-content-center" style="select:invalid { color: gray; }">
<label style="display:flex;
flex-direction:column;
align-items: center;">
<select id="cfrow_id" name="cfrow_id" style="scale:1.2;">
<option value="" disabled selected hidden>Pick counterfactual...</option>
{% for row in cfdf_rows %}
<option value={{row}}>Counterfactual {{row}}</option>
{% endfor %}
</select>
</label>
</div>
</div>
<div class="row" id="og_cf_row">
<div class="col-md d-flex justify-content-center" >
<div class="scrollit" id="original_data">
{{ cfrow_og|safe}}
</div>
</div>
<div class="col-md d-flex justify-content-center" >
<div class="scrollit" id="counterfactual">
<div id="counterfactual_selected">
{{ cfrow_cf|safe }}
</div>
</div>
</div>
</div>
{% endif %}
<br>
<br>
</div>
<div class="container-fluid">
</div>
<br>
<br>
{% endblock content%}

@ -1,10 +1,6 @@
from django.urls import path, include
from . import views
from . import models
urlpatterns = [
path('', views.home, name="home"),
path('preprocess', views.preprocess, name="preprocess"),
path('stats', views.stats, name="stats"),
]

@ -1,5 +0,0 @@
def stats(feature1, feature2, df):
import plotly.express as px
fig = px.scatter(df, x=feature1, y=feature2, color='Churn')
fig = fig.to_html(full_html=False)
return fig

@ -1,27 +1,26 @@
from django.shortcuts import render, redirect
from django.shortcuts import render, HttpResponse
import pandas as pd
from django.core.files.storage import FileSystemStorage
import json
import IPython
import pickle, os
from sklearn.impute import SimpleImputer
import numpy as np
from pandas.api.types import is_string_dtype
from pandas.api.types import is_numeric_dtype
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import classification_report
import plotly.express as px
import plotly.graph_objects as go
from sklearn.preprocessing import LabelEncoder
import joblib
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import dice_ml
import plotly.graph_objects as go
clas_report = None
FILE_NAME = "dataset.csv"
PROCESS_FILE_NAME = "dataset_preprocessed.csv"
def home(request):
global clas_report
# request.session.flush()
if "fig" in request.session:
fig = request.session.get("fig")
@ -43,6 +42,26 @@ def home(request):
else:
tsne = None
if "cfrow_og" in request.session:
cfrow_og = request.session.get("cfrow_og")
else:
cfrow_og = None
if "cfrow_cf" in request.session:
cfrow_cf = request.session.get("cfrow_cf")
else:
cfrow_cf = None
if "cfdf_rows" in request.session:
cfdf_rows = request.session.get("cfdf_rows")
else:
cfdf_rows = None
if "clas_report" in request.session:
clas_report = request.session.get("clas_report")
else:
clas_report = None
if "excel_file_name" in request.session:
excel_file_name = request.session.get("excel_file_name")
else:
@ -60,15 +79,91 @@ def home(request):
excel_file_name_preprocessed = PROCESS_FILE_NAME
request.session["excel_file_name_preprocessed"] = excel_file_name_preprocessed
# ajax request condition
if request.headers.get("X-Requested-With") == "XMLHttpRequest":
if request.POST.get("action") == "click_graph":
# ajax request for graph click
# given the ide of the clicked point (through ajax request)
# locate the respective dataframe row
df = pd.read_csv(excel_file_name_preprocessed)
id = request.POST["row"]
row = df.iloc[[int(id)]]
projections = request.session.get("tsne_projection")
# projections array is a list of pairs with the (x, y)
# coordinates for a point in tsne. These are actual absolute
# coordinates and not SVG.
# Now save the info for use in the future
request.session["clicked_point"] = projections[int(id)]
request.session["cfrow_id"] = request.POST["row"]
request.session["cfrow_og"] = row.to_html()
context = {"row": row.to_html()}
elif request.POST.get("action") == "counterfactual_select":
# if <select> element is used, and a specific counterfactual
# is inquired to be demonstrated:
df = pd.read_csv("counterfactuals.csv")
id = request.POST["row"]
row = df.iloc[[int(id)]]
projections = request.session.get("tsne_projection")
fig = joblib.load("tsne_cfs.sav")
# tsne_cfs is a merged scatter of 2 other scatters
# scatter[0] contains all the points of tsne
# scatter[1] contains all the counterfactual points
# + the point for which counterfactuals were
# computed...
# blur out all the points of the first scatter
fig.data[0].update(
opacity=0.3,
)
# from the second scatter (scatter[1]) locate the clicked_point
# coordinates and keep the index
l = fig.data[1]
clicked_id = -1
for clicked_id, item in enumerate(list(zip(l.x, l.y))):
if (
item[0] == request.session.get("clicked_point")[0]
and item[1] == request.session.get("clicked_point")[1]
):
break
# id is the index of the respective counter factual
# clicked_id is the index of the original point
# Blur all the points apart from these 2
fig.data[1].update(
selectedpoints=[id, clicked_id],
unselected=dict(
marker=dict(
opacity=0.3,
)
),
)
context = {"row": row.to_html(), "fig": fig.to_html()}
elif request.POST.get("action") == "reset_graph":
fig = request.session.get("tsne")
context = {"fig": fig}
return HttpResponse(json.dumps(context))
df = pd.DataFrame()
if request.method == "POST":
if "csv" in request.POST:
excel_file = request.FILES["excel_file"]
excel_file_name = request.FILES["excel_file"].name
fig = None
fig2 = None
pca = None
tsne = None
clas_report = None
cfrow_cf = None
cfrow_og = None
# here we dont use the name of the file since the
# uploaded file is not yet saved
@ -79,11 +174,11 @@ def home(request):
# fs.save(excel_file_name, excel_file)
df = pd.read_csv(excel_file)
df.drop(["id"], axis=1, inplace=True)
# df.drop(["id"], axis=1, inplace=True)
df.to_csv(excel_file_name, index=False)
feature1 = df.columns[0]
feature2 = df.columns[1]
feature1 = df.columns[3]
feature2 = df.columns[2]
request.session["feature1"] = feature1
request.session["feature2"] = feature2
@ -108,7 +203,8 @@ def home(request):
mode = request.POST.get("colorRadio")
model = request.POST.get("model")
test_size = float(request.POST.get("split_input"))
print(test_size, mode, model)
request.session["model"] = model
if mode == "train":
if model == "logit":
con = training(excel_file_name_preprocessed, "logit", test_size)
@ -122,6 +218,9 @@ def home(request):
elif model == "svm":
con = training(excel_file_name_preprocessed, "svm", test_size)
elif model == "rf":
con = training(excel_file_name_preprocessed, "rf", test_size)
fig2 = con["fig2"]
clas_report = con["clas_report"].to_html()
elif mode == "test":
@ -137,6 +236,9 @@ def home(request):
elif model == "svm":
con = testing(excel_file_name_preprocessed, "svm")
elif model == "rf":
con = training(excel_file_name_preprocessed, "rf", test_size)
fig2 = con["fig2"]
clas_report = con["clas_report"].to_html()
@ -146,21 +248,16 @@ def home(request):
# if file for preprocessing does not exist create it
# also apply basic preprocessing
if os.path.exists(excel_file_name_preprocessed) == False:
# generate filename
idx = excel_file_name.index(".")
excel_file_name_preprocessed = (
excel_file_name[:idx] + "_preprocessed" + excel_file_name[idx:]
)
# save file for preprocessing
preprocess_df = pd.read_csv(excel_file_name)
fs = FileSystemStorage() # defaults to MEDIA_ROOT
request.session["excel_file_name_preprocessed"] = (
excel_file_name_preprocessed
)
preprocess_df.to_csv(excel_file_name_preprocessed, index=False)
preprocess_df.drop(
["perimeter_mean", "area_mean"], axis=1, inplace=True
)
@ -180,18 +277,21 @@ def home(request):
inplace=True,
)
# preprocess_df.drop(["id"], axis=1, inplace=True)
le = LabelEncoder()
preprocess_df["diagnosis"] = le.fit_transform(
preprocess_df["diagnosis"]
)
preprocess_df.to_csv(excel_file_name_preprocessed, index=False)
else:
preprocess_df = pd.read_csv(excel_file_name_preprocessed)
preprocess(preprocess_df, value_list, excel_file_name_preprocessed)
preprocess_df.drop(["id"], axis=1, inplace=True)
# PCA
pca = PCA()
pca.fit(preprocess_df)
pca.fit(preprocess_df.loc[:, "radius_mean":])
exp_var_cumul = np.cumsum(pca.explained_variance_ratio_)
pca = px.area(
x=range(1, exp_var_cumul.shape[0] + 1),
@ -199,31 +299,112 @@ def home(request):
labels={"x": "# Components", "y": "Explained Variance"},
).to_html()
features = preprocess_df.loc[:, :"compactness_se"]
# tSNE
tsne = TSNE(n_components=2, random_state=0)
projections = tsne.fit_transform(features)
projections = tsne.fit_transform(
preprocess_df.drop(["diagnosis"], axis=1).values
)
tsne_df = pd.DataFrame(
{
"0": projections[:, 0],
"1": projections[:, 1],
"diagnosis": preprocess_df["diagnosis"],
}
)
tsne = px.scatter(
projections,
x=0,
y=1,
color=preprocess_df.diagnosis,
labels={"color": "diagnosis"},
).to_html()
tsne_df,
x="0",
y="1",
color="diagnosis",
color_continuous_scale=px.colors.sequential.Rainbow,
)
tsne.update_layout(clickmode="event+select")
# tsne_opacity.update_layout(clickmode="event+select")
pickle.dump(tsne, open("tsne.sav", "wb"))
tsne = tsne.to_html()
request.session["tsne_projection"] = projections.tolist()
elif "cf" in request.POST:
excel_file_name_preprocessed = request.session.get(
"excel_file_name_preprocessed"
)
df = pd.read_csv(excel_file_name_preprocessed)
df_id = request.session.get("cfrow_id")
model = request.session.get("model")
row = df.iloc[[int(df_id)]]
counterfactuals(row, model, excel_file_name_preprocessed)
# get coordinates of the clicked point (saved during the actual click)
clicked_point = request.session.get("clicked_point")
clicked_point_df = pd.DataFrame(
{
"0": clicked_point[0],
"1": clicked_point[1],
"diagnosis": row.diagnosis,
}
)
clicked_point_df.reset_index(drop=True)
# tSNE
cf_df = pd.read_csv("counterfactuals.csv")
# get rows count
request.session["cfdf_rows"] = cf_df.index.values.tolist()
# select a cf randomly for demonstration
request.session["cfrow_cf"] = cf_df.iloc[:1].to_html()
df_merged = pd.concat(
[cf_df, df.drop("id", axis=1)], ignore_index=True, axis=0
)
tsne_cf = TSNE(n_components=2, random_state=0)
projections = tsne_cf.fit_transform(
df_merged.drop(["diagnosis"], axis=1).values
)
cf_df = pd.DataFrame(
{
"0": projections[:3, 0],
"1": projections[:3, 1],
"diagnosis": cf_df.diagnosis.iloc[:3],
}
)
cf_df = pd.concat([cf_df, clicked_point_df], ignore_index=True, axis=0)
tsne = joblib.load("tsne.sav")
cf_s = px.scatter(
cf_df,
x="0",
y="1",
color="diagnosis",
color_continuous_scale=px.colors.sequential.Rainbow,
)
cf_s.update_traces(
marker=dict(
size=10,
symbol="circle",
)
)
tsne.add_trace(cf_s.data[0])
pickle.dump(tsne, open("tsne_cfs.sav", "wb"))
tsne = tsne.to_html()
else:
if os.path.exists(excel_file_name) == False:
excel_file_name = "dataset.csv"
request.session["excel_file_name"] = excel_file_name
df = pd.read_csv(excel_file_name)
fig2 = None
df = pd.read_csv(excel_file_name)
# just random columns to plot
feature1 = df.columns[0]
feature2 = df.columns[1]
feature1 = df.columns[3]
feature2 = df.columns[2]
request.session["feature1"] = feature1
request.session["feature2"] = feature2
fig = stats(
@ -237,10 +418,11 @@ def home(request):
request.session["fig2"] = fig2
request.session["pca"] = pca
request.session["tsne"] = tsne
request.session["clas_report"] = clas_report
data_to_display = df[:5].to_html()
data_to_display = df[:10].to_html(index=False)
request.session["data_to_display"] = data_to_display
labels = df.columns
labels = df.columns[2:]
context = {
"data_to_display": data_to_display,
"excel_file": excel_file_name,
@ -252,6 +434,9 @@ def home(request):
"clas_report": clas_report,
"pca": pca,
"tsne": tsne,
"cfrow_og": cfrow_og,
"cfrow_cf": cfrow_cf,
"cfdf_rows": cfdf_rows,
}
return render(request, "base/home.html", context)
@ -278,6 +463,9 @@ def stats(name, feature1, feature2):
else:
# they both are categorical so do scatter
fig = px.histogram(df, x=feature1, color=feature2)
fig.update_layout(clickmode="event+select")
fig = fig.to_html(full_html=False)
return fig
@ -285,27 +473,30 @@ def stats(name, feature1, feature2):
def preprocess(data, value_list, name):
from sklearn.preprocessing import StandardScaler
ids = data["id"]
data = data.drop(["id"], axis=1)
for type in value_list:
if type == "std":
# define standard scaler
scaler = StandardScaler()
y = data["diagnosis"]
if is_numeric_dtype(data["diagnosis"]):
# if class column is numeric do not
# apply preprocessing
data = data.drop(["diagnosis"], axis=1)
# if class column is numeric do not
# apply preprocessing
data = data.drop(["diagnosis"], axis=1)
# transform data
cols = data.select_dtypes(np.number).columns
data[cols] = scaler.fit_transform(data[cols])
y = y.to_frame()
data = data.join(y)
data = pd.concat([y.to_frame(), data], axis=1, ignore_index=False)
if type == "onehot":
data = pd.get_dummies(data)
if type == "imp":
y = data["diagnosis"]
data = data.drop(["diagnosis"], axis=1)
data_numeric = data.select_dtypes(exclude=["object"])
data_categorical = data.select_dtypes(exclude=["number"])
imp = SimpleImputer(missing_values=np.nan, strategy="most_frequent")
@ -318,17 +509,20 @@ def preprocess(data, value_list, name):
[data_numeric, data_categorical], axis=1, ignore_index=False
)
data = pd.concat([y.to_frame(), data], axis=1, ignore_index=False)
new = pd.concat([ids.to_frame(), data], axis=1, ignore_index=False)
os.remove(name)
data.to_csv(name, index=False)
new.to_csv(name, index=False)
return
def training(name, type, test_size=0.7):
data = pd.read_csv(name)
X = data.drop("id", axis=1)
from sklearn.model_selection import train_test_split
y = data["diagnosis"]
X = data.drop("diagnosis", axis=1)
y = X["diagnosis"]
X = X.drop("diagnosis", axis=1)
X_train, X_test, y_train, y_test = train_test_split(
X, y, shuffle=True, test_size=test_size, stratify=y, random_state=42
)
@ -377,6 +571,16 @@ def training(name, type, test_size=0.7):
importance = svc.coef_[0]
model = svc
if "rf" == type:
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
filename = "rf.sav"
importance = rf.feature_importances_
model = rf
clas_report = classification_report(y_test, y_pred, output_dict=True)
clas_report = pd.DataFrame(clas_report).transpose()
clas_report = clas_report.sort_values(by=["f1-score"], ascending=False)
@ -423,6 +627,13 @@ def testing(name, type):
importance = svc.coef_[0]
model = svc
if "rf" == type:
filename = "rf.sav"
rf = joblib.load(filename)
y_pred = rf.predict(X_test)
importance = rf.feature_importances_
model = rf
clas_report = classification_report(y_test, y_pred, output_dict=True)
clas_report = pd.DataFrame(clas_report).transpose()
clas_report = clas_report.sort_values(by=["f1-score"], ascending=False)
@ -433,3 +644,73 @@ def testing(name, type):
"clas_report": clas_report,
}
return con
def counterfactuals(row, model, preprocessed_file):
df = pd.read_csv(preprocessed_file)
df = df.drop("id", axis=1)
row = row.drop(["diagnosis", "id"], axis=1)
if "logit" == model:
filename = "lg.sav"
clf = joblib.load(filename)
model = clf
if "xgb" == model:
filename = "xgb.sav"
xgb = joblib.load(filename)
model = xgb
if "dt" == model:
filename = "dt.sav"
dt = joblib.load(filename)
model = dt
if "svm" == model:
filename = "svc.sav"
svc = joblib.load(filename)
model = svc
if "rf" == model:
filename = "rf.sav"
rf = joblib.load(filename)
model = rf
d = dice_ml.Data(
dataframe=df,
continuous_features=[
"radius_mean",
"texture_mean",
"smoothness_mean",
"compactness_mean",
"concavity_mean",
"symmetry_mean",
"fractal_dimension_mean",
"radius_se",
"texture_se",
"smoothness_se",
"compactness_se",
"concavity_se",
"concave points_se",
"symmetry_se",
"fractal_dimension_se",
"compactness_worst",
"concavity_worst",
"concave points_worst",
"fractal_dimension_worst",
],
outcome_name="diagnosis",
)
m = dice_ml.Model(model=model, backend="sklearn")
exp = dice_ml.Dice(d, m)
dice_exp = exp.generate_counterfactuals(
row, # The data from the 1st row of our dataframe
total_CFs=3, # Total number of Counterfactual Examples we want to print out. There can be multiple.
desired_class="opposite", # We want to convert the quality to the opposite one.
)
# save cfs
dice_exp.cf_examples_list[0].final_cfs_df.to_csv(
path_or_buf="counterfactuals.csv", index=False
)
return dice_exp._cf_examples_list

Binary file not shown.

3
extremum/routing.py Normal file

@ -0,0 +1,3 @@
from channels.routing import ProtocolTypeRouter
application = ProtocolTypeRouter({})

@ -40,6 +40,9 @@ INSTALLED_APPS = [
"django.contrib.staticfiles",
"base.apps.BaseConfig",
"bootstrap5",
"django_plotly_dash.apps.DjangoPlotlyDashConfig",
"channels",
"channels_redis",
]
MIDDLEWARE = [
@ -105,6 +108,7 @@ AUTH_PASSWORD_VALIDATORS = [
# Internationalization
# https://docs.djangoproject.com/en/5.0/topics/i18n/
X_FRAME_OPTIONS = 'SAMEORIGIN'
LANGUAGE_CODE = "en-us"
@ -114,6 +118,32 @@ USE_I18N = True
USE_TZ = True
CRISPY_TEMPLATE = "bootsrap5"
ASGI_APPLICATION = "extremum.routing.applications"
CHANNEL_LAYERS = {
'default': {
'BACKEND': 'channels_redis.core.RedisChannelLayer',
'CONFIG': {
'hosts': [('127.0.0.1', 6379),],
},
},
}
SSTATICFILRES_FINDERS = [
"django.contrib.staticfiles.finders.FileSystemFinder",
"django.contrib.staticfiles.finders.AppDirectoriesFinder",
"django_plotly_dash.finder.DashAssetFinder",
"dajango_plotly_dash.finders.DashComponentsFinder",
]
PLOTLY_COMPONENTS = [
"dash_core_components",
"dash_html_components",
"dash_renderer",
"dpd_components",
]
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/5.0/howto/static-files/

@ -11,7 +11,10 @@
<title>EXTREMUM</title>
<meta name="viewport" content="'width=device-width, initial-scale=1" />
<link rel="stylesheet" href="{% static 'css/style.css' %}">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
<script src="https://cdn.plot.ly/plotly-latest.min.js" charset="utf-8"></script>
<script src="https://code.jquery.com/jquery-3.7.1.js" integrity="sha256-eKhayi8LEQwp4NKxN+CfCh+3qOVUtJn3QNZ0TciWLP4=" crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/svg.js/3.2.0/svg.min.js" charset="utf-8"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css">
</head>
@ -29,12 +32,13 @@
<br>
<br>
<body onscroll="setScroll()" onload="setScreen()">
{% block content %}
{% endblock content %}
<script src="{% static 'js/click_on_graph.js' %}"></script>
<script src="{% static 'js/hide_seek.js' %}"></script>
<script src="{% static 'js/slider.js' %}"></script>
<script src="{% static 'js/keep_scroll_on_load.js' %}"></script>
<script src="{% static 'js/counterfactuals.js' %}"></script>
<script src="{% static 'js/click_reset.js' %}"></script>
</body>
</html>