Counterfactuals demonstration and computation for tabular

This commit is contained in:
atla8167 2024-07-01 18:17:14 +03:00
parent 6047fad7ce
commit 94c682362f
21 changed files with 832 additions and 82 deletions

Binary file not shown.

Binary file not shown.

@ -0,0 +1,54 @@
from base import views
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
from django_plotly_dash import DjangoDash
import plotly.express as px
import pandas as pd
from sklearn.manifold import TSNE
import json
# Stylesheet handed to the embedded Dash application.
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
app = DjangoDash("Tsne", external_stylesheets=external_stylesheets)

# Read the preprocessed dataset and discard the identifier column so it does
# not take part in the projection.
excel_file_name_preprocessed = "breast-cancer_preprocessed.csv"
preprocess_df = pd.read_csv(excel_file_name_preprocessed)
preprocess_df.drop(["id"], axis=1, inplace=True)

# Every column up to and including "compactness_se" feeds the projection.
features = preprocess_df.loc[:, :"compactness_se"]

# Two-component t-SNE with a fixed seed.
tsne = TSNE(n_components=2, random_state=0)
projections = tsne.fit_transform(features)

# NOTE: `tsne` is rebound from the estimator to the scatter figure here.
tsne = px.scatter(
    projections,
    x=0,
    y=1,
    color=preprocess_df["diagnosis"].astype(str),
    labels={"color": "diagnosis"},
)
tsne.update_layout(clickmode="event+select")

# Hidden input that the click callback targets.
_click_data_input = dcc.Input(
    id="click-data",
    type="hidden",
    className="three columns",
)

# Scatter plot with the Plotly mode bar switched off.
_scatter_graph = dcc.Graph(
    id="scatter-plot",
    config={"displayModeBar": False},
    figure=tsne,
)

app.layout = html.Div(
    [
        _scatter_graph,
        html.Div(
            className="row",
            children=[html.Div([_click_data_input])],
        ),
    ]
)
@app.callback(Output("click-data", "value"), Input("scatter-plot", "clickData"))
def display_click_data(clickData, request):
    """Forward a scatter-plot click to the Django views and echo it back.

    The callback output targets the ``value`` property of the hidden
    ``click-data`` input: ``dcc.Input`` exposes ``value`` and has no
    ``children`` property to write to.

    Args:
        clickData: The ``clickData`` property of the ``scatter-plot`` graph.
        request: Presumably the Django request supplied by django-plotly-dash
            as an extra callback argument -- TODO confirm.

    Returns:
        The click payload serialised as indented JSON.
    """
    # NOTE(review): confirm that ``views.counterfactuals`` accepts
    # ``(clickData, request)``; its definition is not visible from here.
    views.counterfactuals(clickData, request)
    return json.dumps(clickData, indent=2)

@ -0,0 +1,16 @@
# Generated by Django 4.2.13 on 2024-06-17 08:51
from django.db import migrations
class Migration(migrations.Migration):
    """Delete the ``Document`` model."""

    # Applied on top of migration 0002 of the ``base`` app.
    dependencies = [("base", "0002_document_delete_upload")]

    operations = [migrations.DeleteModel(name="Document")]

@ -288,12 +288,13 @@ nav input {
height: min-content; height: min-content;
max-height: 40%; max-height: 40%;
max-width: 35%; max-width: 35%;
border: 1px solid #bbb;
padding: 5px;
} }
.dataframe { .dataframe {
font-size: 9pt; font-size: 9pt;
font-family: Arial; font-family: Arial;
border-collapse: collapse;
font-size: 0.9em; font-size: 0.9em;
} }
@ -306,6 +307,13 @@ nav input {
.dataframe td { .dataframe td {
padding: 12px 15px; padding: 12px 15px;
text-align: left; text-align: left;
border: black;
border-collapse: separate;
}
.dataframe .clickedrow th,
.dataframe .clickedrow td {
background-color: #c6bdf8;
} }
.dataframe tbody tr { .dataframe tbody tr {
@ -319,31 +327,36 @@ nav input {
.dataframe tbody tr:last-of-type { .dataframe tbody tr:last-of-type {
border-bottom: 2px solid #009879; border-bottom: 2px solid #009879;
} }
.dataframe tbody tr:hover {
background-color: #e8e5f9;
}
/* plotly toolbar */ /* plotly toolbar */
.modebar { .modebar {
display: none !important; display: none !important;
} }
/* button */ /* button */
/* CSS */ /* CSS */
.button-6 { .button-6 {
align-items: center; align-items: center;
background-color: #FFFFFF; background-color: #ffffff;
border: 1px solid rgba(0, 0, 0, 0.1); border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: .25rem; border-radius: 0.25rem;
box-shadow: rgba(0, 0, 0, 0.02) 0 1px 3px 0; box-shadow: rgba(0, 0, 0, 0.02) 0 1px 3px 0;
box-sizing: border-box; box-sizing: border-box;
color: rgba(0, 0, 0, 0.85); color: rgba(0, 0, 0, 0.85);
cursor: pointer; cursor: pointer;
display: inline-flex; display: inline-flex;
font-family: system-ui,-apple-system,system-ui,"Helvetica Neue",Helvetica,Arial,sans-serif; font-family: system-ui, -apple-system, system-ui, "Helvetica Neue", Helvetica,
Arial, sans-serif;
font-size: 16px; font-size: 16px;
font-weight: 600; font-weight: 600;
justify-content: center; justify-content: center;
line-height: 1.25; line-height: 1.25;
margin: 0; margin: 0;
min-height: 3rem; min-height: 3rem;
padding: calc(.875rem - 1px) calc(1.5rem - 1px); padding: calc(0.875rem - 1px) calc(1.5rem - 1px);
position: relative; position: relative;
text-decoration: none; text-decoration: none;
transition: all 250ms; transition: all 250ms;
@ -366,7 +379,7 @@ nav input {
} }
.button-6:active { .button-6:active {
background-color: #F0F0F1; background-color: #f0f0f1;
border-color: rgba(0, 0, 0, 0.15); border-color: rgba(0, 0, 0, 0.15);
box-shadow: rgba(0, 0, 0, 0.06) 0 2px 4px; box-shadow: rgba(0, 0, 0, 0.06) 0 2px 4px;
color: rgba(0, 0, 0, 0.65); color: rgba(0, 0, 0, 0.65);
@ -379,7 +392,7 @@ nav input {
display: flex; display: flex;
flex-wrap: wrap; flex-wrap: wrap;
border-radius: 0.5rem; border-radius: 0.5rem;
background-color: #EEE; background-color: #eee;
box-sizing: border-box; box-sizing: border-box;
box-shadow: 0 0 0px 1px rgba(0, 0, 0, 0.06); box-shadow: 0 0 0px 1px rgba(0, 0, 0, 0.06);
padding: 0.25rem; padding: 0.25rem;
@ -403,12 +416,32 @@ nav input {
justify-content: center; justify-content: center;
border-radius: 0.5rem; border-radius: 0.5rem;
border: none; border: none;
padding: .5rem 0; padding: 0.5rem 0;
color: rgba(51, 65, 85, 1); color: rgba(51, 65, 85, 1);
transition: all .15s ease-in-out; transition: all 0.15s ease-in-out;
} }
.radio-inputs .radio input:checked + .name { .radio-inputs .radio input:checked + .name {
background-color: #fff; background-color: #fff;
font-weight: 600; font-weight: 600;
} }
/* Spinner: a white ring with one accent-coloured edge, spun continuously by
   the `rotation` keyframes. */
.loader {
  display: inline-block;
  box-sizing: border-box;
  width: 48px;
  height: 48px;
  border-radius: 50%;
  /* The shorthand must come first so the bottom colour override wins. */
  border: 5px solid #fff;
  border-bottom-color: #ff3d00;
  animation: rotation 1s linear infinite;
}
/* One full clockwise turn. */
@keyframes rotation {
  from {
    transform: rotate(0deg);
  }
  to {
    transform: rotate(360deg);
  }
}

@ -0,0 +1,108 @@
$(document).ready(function () {
    // When the t-SNE graph is clicked, look for the point that still has
    // opacity 1; after a selection the rest have opacity 0.2.
    //
    // Case 1: every point has opacity 1, i.e. nothing is selected.
    //   Checking all of them would be inefficient. Finding two opaque points
    //   in the same collection is enough to conclude that no specific point
    //   was clicked.
    // Case 2: a point was actually clicked.
    //   Plotly does not keep all points in one <g> element; in this example
    //   it uses several <g class="points"> collections. We walk every
    //   collection and every point until we find the one with opacity 1.
    //   The scan is O(n) in the number of points.

    // True when the element is rendered fully opaque.
    function isOpaque(element) {
        return Number(getComputedStyle(element).opacity) === 1;
    }

    // Returns the index of the selected point, counted across all point
    // collections in document order, or -1 when nothing is selected.
    function findSelectedPoint(container) {
        var collections = container.getElementsByClassName("points");
        var offset = 0;
        for (var i = 0; i < collections.length; i++) {
            var children = collections[i].children;
            // Case 1 shortcut. It needs at least two points; shorter
            // collections fall through to the scan instead of reading
            // a missing child.
            if (children.length > 1 && isOpaque(children[0]) && isOpaque(children[1])) {
                return -1;
            }
            for (var j = 0; j < children.length; j++) {
                if (isOpaque(children[j])) {
                    return offset + j;
                }
            }
            offset += children.length;
        }
        return -1;
    }

    // Sets the CSS display of an element if it exists.
    function setDisplay(id, value) {
        var element = document.getElementById(id);
        if (element) {
            element.style.display = value;
        }
    }

    var tsne = document.getElementById('tsne');
    if (!tsne) {
        // The graph container is absent when no t-SNE plot was rendered.
        return;
    }

    tsne.addEventListener('click', function () {
        // Hide the original/counterfactual columns of a previous run.
        if (document.getElementById("og_cf_row") && document.getElementById("og_cf_headers")) {
            setDisplay("og_cf_row", "none");
            setDisplay("og_cf_headers", "none");
        }

        var position = findSelectedPoint(tsne);
        if (position < 0) {
            // Nothing selected: drop the row preview and its button.
            $("#cfrow").remove();
            setDisplay("cftable", "none");
            setDisplay("cfbtn", "none");
            return;
        }

        var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
        $.ajax({
            method: 'POST',
            url: '',
            headers: { 'X-CSRFToken': csrftoken },
            data: { 'row': position, 'action': "click_graph" },
            success: function (ret) {
                var payload = typeof ret === "string" ? JSON.parse(ret) : ret;
                var tb = document.createElement('table');
                tb.innerHTML = payload["row"].trim();
                tb.setAttribute("id", "cfrow");
                tb.setAttribute("class", "dataframe");
                // Replace any previously shown row.
                $("#cfrow").remove();
                $("#cftable").append(tb);
                setDisplay("cftable", "block");
                setDisplay("cfbtn", "block");
            },
            error: function () {
                console.log("it didnt work");
            }
        });
    });
});

@ -0,0 +1,30 @@
$(document).ready(function () {
    // Highlight the clicked table row, post its markup to the backend and
    // mark the point whose index the backend returns on the t-SNE plot.
    $('.dataframe td').on('click', function () {
        var rowElement = this.closest('tr');
        if (!rowElement) {
            return;
        }

        // Move the highlight from the previously clicked row to this one.
        $(this).closest('.dataframe').find('tr.clickedrow').removeClass('clickedrow');
        $(rowElement).addClass("clickedrow");

        // Row markup with per-line whitespace stripped.
        var rowHtml = rowElement.innerHTML.trim().split("\n").map(line => line.trim()).join("\n");

        var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
        // NOTE(review): no 'action' field is sent with this request; confirm
        // that the backend handles it.
        $.ajax({
            method: 'POST',
            url: '',
            headers: { 'X-CSRFToken': csrftoken },
            data: { 'row': rowHtml },
            success: function (id) {
                // The plot, its point collection or the requested point may
                // be missing; do nothing in that case instead of throwing.
                var myPlot = document.getElementById('tsne');
                var paths = myPlot ? myPlot.querySelector('.points') : null;
                var path = paths ? paths.children[id] : undefined;
                if (!path) {
                    console.log("it didnt work");
                    return;
                }
                path.setAttribute("style", "opacity: 1; stroke-width: 0px; fill: green; fill-opacity: 1; scale: 2.0;");
            },
            error: function () {
                console.log("it didnt work");
            }
        });
    });
});

@ -0,0 +1,27 @@
$(document).ready(function () {
    // Reset the t-SNE graph when the reset control is clicked.

    // Sets the CSS display of an element if it exists.
    function setDisplay(id, value) {
        var element = document.getElementById(id);
        if (element) {
            element.style.display = value;
        }
    }

    var resetDiv = document.getElementById('reset_div');
    if (!resetDiv) {
        // The reset container is absent until counterfactuals are rendered.
        return;
    }

    resetDiv.addEventListener('click', function () {
        // Swap the reset control for the loader while the request runs.
        setDisplay("reset_div", "none");
        setDisplay("loader_cf", "block");

        var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
        $.ajax({
            method: 'POST',
            url: '',
            headers: { 'X-CSRFToken': csrftoken },
            data: { 'action': "reset_graph" },
            success: function (ret) {
                setDisplay("loader_cf", "none");
                var payload = typeof ret === "string" ? JSON.parse(ret) : ret;
                // Replace the current figure with the one sent back.
                var tsne = document.getElementById("tsne");
                if (tsne) {
                    tsne.innerHTML = "";
                    $(tsne).append(payload["fig"]);
                }
                // The graph is back to its initial state; keep reset hidden.
                setDisplay("reset_div", "none");
            },
            error: function () {
                // Stop the loader and offer the reset control again so the
                // user is not left with an endless spinner.
                setDisplay("loader_cf", "none");
                setDisplay("reset_div", "block");
                console.log("it didnt work");
            }
        });
    });
});

@ -0,0 +1,62 @@
$(document).ready(function () {
    // When a counterfactual is picked from the <select>, ask the backend for
    // it and replace the current counterfactual view and the graph with the
    // returned ones.

    // Sets the CSS display of an element if it exists.
    function setDisplay(id, value) {
        var element = document.getElementById(id);
        if (element) {
            element.style.display = value;
        }
    }

    var mySelect = document.getElementById('cfrow_id');
    if (!mySelect) {
        // The selector is absent until counterfactuals are rendered.
        return;
    }

    mySelect.onchange = (event) => {
        // The option value is the row id of the counterfactual.
        var rowId = event.target.value;
        var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();

        setDisplay("loader_cf", "block");
        // If the reset button exists, hide it while loading.
        setDisplay("reset_div", "none");
        $("#show_image").show();

        $.ajax({
            method: 'POST',
            url: '',
            headers: { 'X-CSRFToken': csrftoken },
            data: { 'row': rowId, 'action': "counterfactual_select" },
            success: function (ret) {
                // Stop the loader.
                setDisplay("loader_cf", "none");

                // Create <input type="reset" value="Reset"> once; afterwards
                // only its container is shown again.
                if (!document.getElementById("reset")) {
                    var reset = document.createElement('input');
                    reset.setAttribute("id", "reset");
                    reset.setAttribute("type", "reset");
                    reset.setAttribute("value", "Reset");
                    $("#reset_div").append(reset);
                }
                setDisplay("reset_div", "block");

                var payload = typeof ret === "string" ? JSON.parse(ret) : ret;

                // Swap the displayed counterfactual for the requested one.
                var tb = document.createElement('table');
                tb.innerHTML = payload["row"].trim();
                tb.setAttribute("id", "counterfactual_selected");
                tb.setAttribute("class", "dataframe");
                $("#counterfactual_selected").remove();
                $("#counterfactual").append(tb);
                tb.style.display = "block";

                // Replace the graph with the returned figure.
                var tsne = document.getElementById("tsne");
                if (tsne) {
                    tsne.innerHTML = "";
                    $(tsne).append(payload["fig"]);
                }
            },
            error: function () {
                // Do not leave the loader spinning after a failed request.
                setDisplay("loader_cf", "none");
                console.log("it didnt work");
            }
        });
    };
});

@ -3,7 +3,7 @@ function setScreen() {
window.scrollTo(0, yScreen); window.scrollTo(0, yScreen);
} }
function setScroll() { function setScroll() {
var yScroll = window.pageYOffset; var yScroll = window.scrollY;
localStorage.setItem("yPos", yScroll); localStorage.setItem("yPos", yScroll);
} }
function clearScreen() { function clearScreen() {

@ -2,6 +2,26 @@
{% block content %} {% block content %}
{% load static %} {% load static %}
<!-- <div class="container-fluid">
<div class="row">
<div class="col-sm d-flex justify-content-center" >
<div class="dataset">
<div class="form-check">
<input class="form-check-input" type="radio" name="flexRadioDefault" id="flexRadioDefault1">
<label class="form-check-label" for="flexRadioDefault1">
Breast Cancer Dataset
</label>
</div>
<div class="form-check">
<input class="form-check-input" type="radio" name="flexRadioDefault" id="flexRadioDefault2">
<label class="form-check-label" for="flexRadioDefault2">
Other
</label>
</div>
</div>
</div>
</div>
</div> -->
<div class="container-fluid"> <div class="container-fluid">
<div class="row"> <div class="row">
<div class="col-sm d-flex justify-content-center"> <div class="col-sm d-flex justify-content-center">
@ -68,8 +88,10 @@
</div> </div>
</div> </div>
</div> </div>
<br> <br>
<br> <br>
<div class="container-fluid"> <div class="container-fluid">
<div class="row"> <div class="row">
<section> <section>
@ -113,24 +135,19 @@
</div> </div>
</div> </div>
<br>
<br>
<div class="row"> <div class="row">
{% if pca %} {% if pca %}
<div class="col-sm d-flex justify-content-center"> <div class="col d-flex justify-content-center" id="pca">
{{ pca|safe }} {{ pca|safe }}
</div> </div>
{% endif %} {% endif %}
{% if tsne %}
<div class="col-sm d-flex justify-content-center">
{{ tsne|safe }}
</div> </div>
{% endif %}
</div>
<br>
<br>
</div> </div>
<div class="container-fluid"> <div class="container-fluid">
<form action="{% url 'home' %}" method="POST" id="traintest-form"> <form action="{% url 'home' %}" method="POST" id="traintest-form">
{% csrf_token %} {% csrf_token %}
@ -147,7 +164,7 @@
<option type="submit" value="logit" selected>Logistic Regression</option> <option type="submit" value="logit" selected>Logistic Regression</option>
<option type="submit" value="xgb">XGBoost</option> <option type="submit" value="xgb">XGBoost</option>
<option type="submit" value="dt">Decision Tree</option> <option type="submit" value="dt">Decision Tree</option>
<option type="submit" value="rt">Random Forest</option> <option type="submit" value="rf">Random Forest</option>
</select> </select>
</label> </label>
</div> </div>
@ -200,6 +217,7 @@
{{ fig2|safe }} {{ fig2|safe }}
</div> </div>
{% endif %} {% endif %}
{% if clas_report %} {% if clas_report %}
<div class="col-sm d-flex justify-content-center"> <div class="col-sm d-flex justify-content-center">
<div class="scrollit"> <div class="scrollit">
@ -209,4 +227,97 @@
{% endif %} {% endif %}
</div> </div>
</div> </div>
<br>
<br>
<div class="container-fluid">
<div class="row">
<section>
<label style="display:flex;
flex-direction:column;
align-items: center;">
<h2>
<i class="fas fa-cog"></i> Counterfactuals
</h2>
<h5>
Pick a point on the graph
</h5>
</label>
</section>
<div class="col-sm d-flex justify-content-center">
</div>
</div>
{% if tsne %}
<div class="row">
<div class="col-md justify-content-center" id="tsne">
{{ tsne|safe }}
</div>
</div>
{% endif %}
<div class="row">
<form action="{% url 'home' %}" method="POST" id="cf-form">
{% csrf_token %}
<div class="col-md d-flex justify-content-center">
<div class="scrollit" class="cftable" name="cfrow" id="cftable" style="border: 1px solid #bbb; padding: 5px; display: none; max-width: 65%;" >
</div>
</div>
</form>
</div>
<div class="row">
<div class="col-sm d-flex justify-content-center">
<button id="cfbtn" type="submit" class="button-6" role="button" name="cf" form="cf-form" style="display:none;">Run Counterfactuals!</button>
</div>
</div>
{% if cfrow_og and cfrow_cf %}
<div class="row" id="og_cf_headers">
<div class="col d-flex justify-content-center">
<h2>
Original data
</h2>
</div>
<span style="display: none; width: 48px; height: 48px;" id="reset_div">
</span>
<span class="loader" id="loader_cf" style="display: none;">
</span>
<div class="col d-flex justify-content-center" style="select:invalid { color: gray; }">
<label style="display:flex;
flex-direction:column;
align-items: center;">
<select id="cfrow_id" name="cfrow_id" style="scale:1.2;">
<option value="" disabled selected hidden>Pick counterfactual...</option>
{% for row in cfdf_rows %}
<option value="{{ row }}">Counterfactual {{ row }}</option>
{% endfor %}
</select>
</label>
</div>
</div>
<div class="row" id="og_cf_row">
<div class="col-md d-flex justify-content-center" >
<div class="scrollit" id="original_data">
{{ cfrow_og|safe}}
</div>
</div>
<div class="col-md d-flex justify-content-center" >
<div class="scrollit" id="counterfactual">
<div id="counterfactual_selected">
{{ cfrow_cf|safe }}
</div>
</div>
</div>
</div>
{% endif %}
<br>
<br>
</div>
<div class="container-fluid">
</div>
<br>
<br>
{% endblock content%} {% endblock content%}

@ -1,10 +1,6 @@
from django.urls import path, include from django.urls import path, include
from . import views from . import views
from . import models
urlpatterns = [ urlpatterns = [
path('', views.home, name="home"), path('', views.home, name="home"),
path('preprocess', views.preprocess, name="preprocess"),
path('stats', views.stats, name="stats"),
] ]

@ -1,5 +0,0 @@
def stats(feature1, feature2, df):
import plotly.express as px
fig = px.scatter(df, x=feature1, y=feature2, color='Churn')
fig = fig.to_html(full_html=False)
return fig

@ -1,27 +1,26 @@
from django.shortcuts import render, redirect from django.shortcuts import render, HttpResponse
import pandas as pd import pandas as pd
from django.core.files.storage import FileSystemStorage import json
import IPython
import pickle, os import pickle, os
from sklearn.impute import SimpleImputer from sklearn.impute import SimpleImputer
import numpy as np import numpy as np
from pandas.api.types import is_string_dtype
from pandas.api.types import is_numeric_dtype from pandas.api.types import is_numeric_dtype
from sklearn.metrics import accuracy_score, classification_report from sklearn.metrics import classification_report
import plotly.express as px import plotly.express as px
import plotly.graph_objects as go
from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import LabelEncoder
import joblib import joblib
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
import dice_ml
import plotly.graph_objects as go
clas_report = None
FILE_NAME = "dataset.csv" FILE_NAME = "dataset.csv"
PROCESS_FILE_NAME = "dataset_preprocessed.csv" PROCESS_FILE_NAME = "dataset_preprocessed.csv"
def home(request): def home(request):
global clas_report
# request.session.flush() # request.session.flush()
if "fig" in request.session: if "fig" in request.session:
fig = request.session.get("fig") fig = request.session.get("fig")
@ -43,6 +42,26 @@ def home(request):
else: else:
tsne = None tsne = None
if "cfrow_og" in request.session:
cfrow_og = request.session.get("cfrow_og")
else:
cfrow_og = None
if "cfrow_cf" in request.session:
cfrow_cf = request.session.get("cfrow_cf")
else:
cfrow_cf = None
if "cfdf_rows" in request.session:
cfdf_rows = request.session.get("cfdf_rows")
else:
cfdf_rows = None
if "clas_report" in request.session:
clas_report = request.session.get("clas_report")
else:
clas_report = None
if "excel_file_name" in request.session: if "excel_file_name" in request.session:
excel_file_name = request.session.get("excel_file_name") excel_file_name = request.session.get("excel_file_name")
else: else:
@ -60,15 +79,91 @@ def home(request):
excel_file_name_preprocessed = PROCESS_FILE_NAME excel_file_name_preprocessed = PROCESS_FILE_NAME
request.session["excel_file_name_preprocessed"] = excel_file_name_preprocessed request.session["excel_file_name_preprocessed"] = excel_file_name_preprocessed
# ajax request condition
if request.headers.get("X-Requested-With") == "XMLHttpRequest":
if request.POST.get("action") == "click_graph":
# ajax request for graph click
# given the id of the clicked point (through ajax request)
# locate the respective dataframe row
df = pd.read_csv(excel_file_name_preprocessed)
id = request.POST["row"]
row = df.iloc[[int(id)]]
projections = request.session.get("tsne_projection")
# projections array is a list of pairs with the (x, y)
# coordinates for a point in tsne. These are actual absolute
# coordinates and not SVG.
# Now save the info for use in the future
request.session["clicked_point"] = projections[int(id)]
request.session["cfrow_id"] = request.POST["row"]
request.session["cfrow_og"] = row.to_html()
context = {"row": row.to_html()}
elif request.POST.get("action") == "counterfactual_select":
# if <select> element is used, and a specific counterfactual
# is inquired to be demonstrated:
df = pd.read_csv("counterfactuals.csv")
id = request.POST["row"]
row = df.iloc[[int(id)]]
projections = request.session.get("tsne_projection")
fig = joblib.load("tsne_cfs.sav")
# tsne_cfs is a merged scatter of 2 other scatters
# scatter[0] contains all the points of tsne
# scatter[1] contains all the counterfactual points
# + the point for which counterfactuals were
# computed...
# blur out all the points of the first scatter
fig.data[0].update(
opacity=0.3,
)
# from the second scatter (scatter[1]) locate the clicked_point
# coordinates and keep the index
l = fig.data[1]
clicked_id = -1
for clicked_id, item in enumerate(list(zip(l.x, l.y))):
if (
item[0] == request.session.get("clicked_point")[0]
and item[1] == request.session.get("clicked_point")[1]
):
break
# id is the index of the respective counter factual
# clicked_id is the index of the original point
# Blur all the points apart from these 2
fig.data[1].update(
selectedpoints=[id, clicked_id],
unselected=dict(
marker=dict(
opacity=0.3,
)
),
)
context = {"row": row.to_html(), "fig": fig.to_html()}
elif request.POST.get("action") == "reset_graph":
fig = request.session.get("tsne")
context = {"fig": fig}
return HttpResponse(json.dumps(context))
df = pd.DataFrame() df = pd.DataFrame()
if request.method == "POST": if request.method == "POST":
if "csv" in request.POST: if "csv" in request.POST:
excel_file = request.FILES["excel_file"] excel_file = request.FILES["excel_file"]
excel_file_name = request.FILES["excel_file"].name excel_file_name = request.FILES["excel_file"].name
fig = None fig = None
fig2 = None fig2 = None
pca = None
tsne = None
clas_report = None
cfrow_cf = None
cfrow_og = None
# here we dont use the name of the file since the # here we dont use the name of the file since the
# uploaded file is not yet saved # uploaded file is not yet saved
@ -79,11 +174,11 @@ def home(request):
# fs.save(excel_file_name, excel_file) # fs.save(excel_file_name, excel_file)
df = pd.read_csv(excel_file) df = pd.read_csv(excel_file)
df.drop(["id"], axis=1, inplace=True) # df.drop(["id"], axis=1, inplace=True)
df.to_csv(excel_file_name, index=False) df.to_csv(excel_file_name, index=False)
feature1 = df.columns[0] feature1 = df.columns[3]
feature2 = df.columns[1] feature2 = df.columns[2]
request.session["feature1"] = feature1 request.session["feature1"] = feature1
request.session["feature2"] = feature2 request.session["feature2"] = feature2
@ -108,7 +203,8 @@ def home(request):
mode = request.POST.get("colorRadio") mode = request.POST.get("colorRadio")
model = request.POST.get("model") model = request.POST.get("model")
test_size = float(request.POST.get("split_input")) test_size = float(request.POST.get("split_input"))
print(test_size, mode, model)
request.session["model"] = model
if mode == "train": if mode == "train":
if model == "logit": if model == "logit":
con = training(excel_file_name_preprocessed, "logit", test_size) con = training(excel_file_name_preprocessed, "logit", test_size)
@ -122,6 +218,9 @@ def home(request):
elif model == "svm": elif model == "svm":
con = training(excel_file_name_preprocessed, "svm", test_size) con = training(excel_file_name_preprocessed, "svm", test_size)
elif model == "rf":
con = training(excel_file_name_preprocessed, "rf", test_size)
fig2 = con["fig2"] fig2 = con["fig2"]
clas_report = con["clas_report"].to_html() clas_report = con["clas_report"].to_html()
elif mode == "test": elif mode == "test":
@ -137,6 +236,9 @@ def home(request):
elif model == "svm": elif model == "svm":
con = testing(excel_file_name_preprocessed, "svm") con = testing(excel_file_name_preprocessed, "svm")
elif model == "rf":
con = training(excel_file_name_preprocessed, "rf", test_size)
fig2 = con["fig2"] fig2 = con["fig2"]
clas_report = con["clas_report"].to_html() clas_report = con["clas_report"].to_html()
@ -146,21 +248,16 @@ def home(request):
# if file for preprocessing does not exist create it # if file for preprocessing does not exist create it
# also apply basic preprocessing # also apply basic preprocessing
if os.path.exists(excel_file_name_preprocessed) == False: if os.path.exists(excel_file_name_preprocessed) == False:
# generate filename # generate filename
idx = excel_file_name.index(".") idx = excel_file_name.index(".")
excel_file_name_preprocessed = ( excel_file_name_preprocessed = (
excel_file_name[:idx] + "_preprocessed" + excel_file_name[idx:] excel_file_name[:idx] + "_preprocessed" + excel_file_name[idx:]
) )
# save file for preprocessing # save file for preprocessing
preprocess_df = pd.read_csv(excel_file_name) preprocess_df = pd.read_csv(excel_file_name)
fs = FileSystemStorage() # defaults to MEDIA_ROOT
request.session["excel_file_name_preprocessed"] = ( request.session["excel_file_name_preprocessed"] = (
excel_file_name_preprocessed excel_file_name_preprocessed
) )
preprocess_df.to_csv(excel_file_name_preprocessed, index=False)
preprocess_df.drop( preprocess_df.drop(
["perimeter_mean", "area_mean"], axis=1, inplace=True ["perimeter_mean", "area_mean"], axis=1, inplace=True
) )
@ -180,18 +277,21 @@ def home(request):
inplace=True, inplace=True,
) )
# preprocess_df.drop(["id"], axis=1, inplace=True)
le = LabelEncoder() le = LabelEncoder()
preprocess_df["diagnosis"] = le.fit_transform( preprocess_df["diagnosis"] = le.fit_transform(
preprocess_df["diagnosis"] preprocess_df["diagnosis"]
) )
preprocess_df.to_csv(excel_file_name_preprocessed, index=False)
else: else:
preprocess_df = pd.read_csv(excel_file_name_preprocessed) preprocess_df = pd.read_csv(excel_file_name_preprocessed)
preprocess(preprocess_df, value_list, excel_file_name_preprocessed) preprocess(preprocess_df, value_list, excel_file_name_preprocessed)
preprocess_df.drop(["id"], axis=1, inplace=True)
# PCA
pca = PCA() pca = PCA()
pca.fit(preprocess_df) pca.fit(preprocess_df.loc[:, "radius_mean":])
exp_var_cumul = np.cumsum(pca.explained_variance_ratio_) exp_var_cumul = np.cumsum(pca.explained_variance_ratio_)
pca = px.area( pca = px.area(
x=range(1, exp_var_cumul.shape[0] + 1), x=range(1, exp_var_cumul.shape[0] + 1),
@ -199,31 +299,112 @@ def home(request):
labels={"x": "# Components", "y": "Explained Variance"}, labels={"x": "# Components", "y": "Explained Variance"},
).to_html() ).to_html()
features = preprocess_df.loc[:, :"compactness_se"] # tSNE
tsne = TSNE(n_components=2, random_state=0) tsne = TSNE(n_components=2, random_state=0)
projections = tsne.fit_transform(features) projections = tsne.fit_transform(
preprocess_df.drop(["diagnosis"], axis=1).values
)
tsne_df = pd.DataFrame(
{
"0": projections[:, 0],
"1": projections[:, 1],
"diagnosis": preprocess_df["diagnosis"],
}
)
tsne = px.scatter( tsne = px.scatter(
projections, tsne_df,
x=0, x="0",
y=1, y="1",
color=preprocess_df.diagnosis, color="diagnosis",
labels={"color": "diagnosis"}, color_continuous_scale=px.colors.sequential.Rainbow,
).to_html() )
tsne.update_layout(clickmode="event+select")
# tsne_opacity.update_layout(clickmode="event+select")
pickle.dump(tsne, open("tsne.sav", "wb"))
tsne = tsne.to_html()
request.session["tsne_projection"] = projections.tolist()
elif "cf" in request.POST:
excel_file_name_preprocessed = request.session.get(
"excel_file_name_preprocessed"
)
df = pd.read_csv(excel_file_name_preprocessed)
df_id = request.session.get("cfrow_id")
model = request.session.get("model")
row = df.iloc[[int(df_id)]]
counterfactuals(row, model, excel_file_name_preprocessed)
# get coordinates of the clicked point (saved during the actual click)
clicked_point = request.session.get("clicked_point")
clicked_point_df = pd.DataFrame(
{
"0": clicked_point[0],
"1": clicked_point[1],
"diagnosis": row.diagnosis,
}
)
clicked_point_df.reset_index(drop=True)
# tSNE
cf_df = pd.read_csv("counterfactuals.csv")
# get rows count
request.session["cfdf_rows"] = cf_df.index.values.tolist()
# select a cf randomly for demonstration
request.session["cfrow_cf"] = cf_df.iloc[:1].to_html()
df_merged = pd.concat(
[cf_df, df.drop("id", axis=1)], ignore_index=True, axis=0
)
tsne_cf = TSNE(n_components=2, random_state=0)
projections = tsne_cf.fit_transform(
df_merged.drop(["diagnosis"], axis=1).values
)
cf_df = pd.DataFrame(
{
"0": projections[:3, 0],
"1": projections[:3, 1],
"diagnosis": cf_df.diagnosis.iloc[:3],
}
)
cf_df = pd.concat([cf_df, clicked_point_df], ignore_index=True, axis=0)
tsne = joblib.load("tsne.sav")
cf_s = px.scatter(
cf_df,
x="0",
y="1",
color="diagnosis",
color_continuous_scale=px.colors.sequential.Rainbow,
)
cf_s.update_traces(
marker=dict(
size=10,
symbol="circle",
)
)
tsne.add_trace(cf_s.data[0])
pickle.dump(tsne, open("tsne_cfs.sav", "wb"))
tsne = tsne.to_html()
else: else:
if os.path.exists(excel_file_name) == False: if os.path.exists(excel_file_name) == False:
excel_file_name = "dataset.csv" excel_file_name = "dataset.csv"
request.session["excel_file_name"] = excel_file_name request.session["excel_file_name"] = excel_file_name
df = pd.read_csv(excel_file_name) df = pd.read_csv(excel_file_name)
fig2 = None
# just random columns to plot # just random columns to plot
feature1 = df.columns[0] feature1 = df.columns[3]
feature2 = df.columns[1] feature2 = df.columns[2]
request.session["feature1"] = feature1 request.session["feature1"] = feature1
request.session["feature2"] = feature2 request.session["feature2"] = feature2
fig = stats( fig = stats(
@ -237,10 +418,11 @@ def home(request):
request.session["fig2"] = fig2 request.session["fig2"] = fig2
request.session["pca"] = pca request.session["pca"] = pca
request.session["tsne"] = tsne request.session["tsne"] = tsne
request.session["clas_report"] = clas_report
data_to_display = df[:5].to_html() data_to_display = df[:10].to_html(index=False)
request.session["data_to_display"] = data_to_display request.session["data_to_display"] = data_to_display
labels = df.columns labels = df.columns[2:]
context = { context = {
"data_to_display": data_to_display, "data_to_display": data_to_display,
"excel_file": excel_file_name, "excel_file": excel_file_name,
@ -252,6 +434,9 @@ def home(request):
"clas_report": clas_report, "clas_report": clas_report,
"pca": pca, "pca": pca,
"tsne": tsne, "tsne": tsne,
"cfrow_og": cfrow_og,
"cfrow_cf": cfrow_cf,
"cfdf_rows": cfdf_rows,
} }
return render(request, "base/home.html", context) return render(request, "base/home.html", context)
@ -278,6 +463,9 @@ def stats(name, feature1, feature2):
else: else:
# they both are categorical so do scatter # they both are categorical so do scatter
fig = px.histogram(df, x=feature1, color=feature2) fig = px.histogram(df, x=feature1, color=feature2)
fig.update_layout(clickmode="event+select")
fig = fig.to_html(full_html=False) fig = fig.to_html(full_html=False)
return fig return fig
@ -285,13 +473,14 @@ def stats(name, feature1, feature2):
def preprocess(data, value_list, name): def preprocess(data, value_list, name):
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
ids = data["id"]
data = data.drop(["id"], axis=1)
for type in value_list: for type in value_list:
if type == "std": if type == "std":
# define standard scaler # define standard scaler
scaler = StandardScaler() scaler = StandardScaler()
y = data["diagnosis"] y = data["diagnosis"]
if is_numeric_dtype(data["diagnosis"]):
# if class column is numeric do not # if class column is numeric do not
# apply preprocessing # apply preprocessing
data = data.drop(["diagnosis"], axis=1) data = data.drop(["diagnosis"], axis=1)
@ -299,13 +488,15 @@ def preprocess(data, value_list, name):
# transform data # transform data
cols = data.select_dtypes(np.number).columns cols = data.select_dtypes(np.number).columns
data[cols] = scaler.fit_transform(data[cols]) data[cols] = scaler.fit_transform(data[cols])
y = y.to_frame() data = pd.concat([y.to_frame(), data], axis=1, ignore_index=False)
data = data.join(y)
if type == "onehot": if type == "onehot":
data = pd.get_dummies(data) data = pd.get_dummies(data)
if type == "imp": if type == "imp":
y = data["diagnosis"]
data = data.drop(["diagnosis"], axis=1)
data_numeric = data.select_dtypes(exclude=["object"]) data_numeric = data.select_dtypes(exclude=["object"])
data_categorical = data.select_dtypes(exclude=["number"]) data_categorical = data.select_dtypes(exclude=["number"])
imp = SimpleImputer(missing_values=np.nan, strategy="most_frequent") imp = SimpleImputer(missing_values=np.nan, strategy="most_frequent")
@ -318,17 +509,20 @@ def preprocess(data, value_list, name):
[data_numeric, data_categorical], axis=1, ignore_index=False [data_numeric, data_categorical], axis=1, ignore_index=False
) )
data = pd.concat([y.to_frame(), data], axis=1, ignore_index=False)
new = pd.concat([ids.to_frame(), data], axis=1, ignore_index=False)
os.remove(name) os.remove(name)
data.to_csv(name, index=False) new.to_csv(name, index=False)
return return
def training(name, type, test_size=0.7): def training(name, type, test_size=0.7):
data = pd.read_csv(name) data = pd.read_csv(name)
X = data.drop("id", axis=1)
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
y = data["diagnosis"] y = X["diagnosis"]
X = data.drop("diagnosis", axis=1) X = X.drop("diagnosis", axis=1)
X_train, X_test, y_train, y_test = train_test_split( X_train, X_test, y_train, y_test = train_test_split(
X, y, shuffle=True, test_size=test_size, stratify=y, random_state=42 X, y, shuffle=True, test_size=test_size, stratify=y, random_state=42
) )
@ -377,6 +571,16 @@ def training(name, type, test_size=0.7):
importance = svc.coef_[0] importance = svc.coef_[0]
model = svc model = svc
if "rf" == type:
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
filename = "rf.sav"
importance = rf.feature_importances_
model = rf
clas_report = classification_report(y_test, y_pred, output_dict=True) clas_report = classification_report(y_test, y_pred, output_dict=True)
clas_report = pd.DataFrame(clas_report).transpose() clas_report = pd.DataFrame(clas_report).transpose()
clas_report = clas_report.sort_values(by=["f1-score"], ascending=False) clas_report = clas_report.sort_values(by=["f1-score"], ascending=False)
@ -423,6 +627,13 @@ def testing(name, type):
importance = svc.coef_[0] importance = svc.coef_[0]
model = svc model = svc
if "rf" == type:
filename = "rf.sav"
rf = joblib.load(filename)
y_pred = rf.predict(X_test)
importance = rf.feature_importances_
model = rf
clas_report = classification_report(y_test, y_pred, output_dict=True) clas_report = classification_report(y_test, y_pred, output_dict=True)
clas_report = pd.DataFrame(clas_report).transpose() clas_report = pd.DataFrame(clas_report).transpose()
clas_report = clas_report.sort_values(by=["f1-score"], ascending=False) clas_report = clas_report.sort_values(by=["f1-score"], ascending=False)
@ -433,3 +644,73 @@ def testing(name, type):
"clas_report": clas_report, "clas_report": clas_report,
} }
return con return con
def counterfactuals(row, model, preprocessed_file):
    """Generate DiCE counterfactuals for ``row`` and save them to CSV.

    Args:
        row: Query instance(s) to explain. ``id`` and ``diagnosis`` are
            dropped with ``row.drop([...], axis=1)``, so this is expected to
            be a pandas DataFrame containing both columns.
        model: Either one of the classifier keys ``"logit"``, ``"xgb"``,
            ``"dt"``, ``"svm"``, ``"rf"`` (the matching ``*.sav`` file is
            loaded with joblib), or an already-loaded estimator object,
            which is handed to DiCE unchanged.
        preprocessed_file: Path of the preprocessed CSV that DiCE uses as
            its reference data; must contain ``id`` and ``diagnosis``.

    Returns:
        The list of counterfactual examples produced by DiCE, one entry per
        query instance.

    Raises:
        ValueError: If ``model`` is a string that is not a known key.

    Side effects:
        Writes the counterfactuals of the first query instance to
        ``counterfactuals.csv`` in the current working directory.
    """
    # Classifier key -> file the trained estimator is stored in.
    model_files = {
        "logit": "lg.sav",
        "xgb": "xgb.sav",
        "dt": "dt.sav",
        "svm": "svc.sav",
        "rf": "rf.sav",
    }
    if isinstance(model, str):
        # Fail early with a clear message instead of passing a bare string
        # on to dice_ml.Model.
        if model not in model_files:
            raise ValueError(
                f"Unknown model type {model!r}; "
                f"expected one of {sorted(model_files)}"
            )
        model = joblib.load(model_files[model])

    df = pd.read_csv(preprocessed_file)
    df = df.drop("id", axis=1)
    # The query instance must hold features only: no identifier, no label.
    row = row.drop(["diagnosis", "id"], axis=1)

    continuous_features = [
        "radius_mean",
        "texture_mean",
        "smoothness_mean",
        "compactness_mean",
        "concavity_mean",
        "symmetry_mean",
        "fractal_dimension_mean",
        "radius_se",
        "texture_se",
        "smoothness_se",
        "compactness_se",
        "concavity_se",
        "concave points_se",
        "symmetry_se",
        "fractal_dimension_se",
        "compactness_worst",
        "concavity_worst",
        "concave points_worst",
        "fractal_dimension_worst",
    ]
    d = dice_ml.Data(
        dataframe=df,
        continuous_features=continuous_features,
        outcome_name="diagnosis",
    )
    m = dice_ml.Model(model=model, backend="sklearn")
    exp = dice_ml.Dice(d, m)
    dice_exp = exp.generate_counterfactuals(
        row,  # query instance(s) to explain
        total_CFs=3,  # number of counterfactuals per query instance
        desired_class="opposite",  # flip the predicted diagnosis
    )

    # Persist the counterfactuals of the first query instance.
    dice_exp.cf_examples_list[0].final_cfs_df.to_csv(
        path_or_buf="counterfactuals.csv", index=False
    )
    # Same list as before, read through the public accessor already used above.
    return dice_exp.cf_examples_list

Binary file not shown.

3
extremum/routing.py Normal file

@ -0,0 +1,3 @@
from channels.routing import ProtocolTypeRouter
# Root ASGI application object for the project. The router is created with
# an empty mapping, i.e. no protocol handlers are registered here.
application = ProtocolTypeRouter({})

@ -40,6 +40,9 @@ INSTALLED_APPS = [
"django.contrib.staticfiles", "django.contrib.staticfiles",
"base.apps.BaseConfig", "base.apps.BaseConfig",
"bootstrap5", "bootstrap5",
"django_plotly_dash.apps.DjangoPlotlyDashConfig",
"channels",
"channels_redis",
] ]
MIDDLEWARE = [ MIDDLEWARE = [
@ -105,6 +108,7 @@ AUTH_PASSWORD_VALIDATORS = [
# Internationalization # Internationalization
# https://docs.djangoproject.com/en/5.0/topics/i18n/ # https://docs.djangoproject.com/en/5.0/topics/i18n/
X_FRAME_OPTIONS = 'SAMEORIGIN'
LANGUAGE_CODE = "en-us" LANGUAGE_CODE = "en-us"
@ -114,6 +118,32 @@ USE_I18N = True
USE_TZ = True USE_TZ = True
# NOTE(review): django-crispy-forms reads CRISPY_TEMPLATE_PACK rather than
# CRISPY_TEMPLATE -- confirm which setting is intended. Only the misspelled
# value ("bootsrap5") is corrected here.
CRISPY_TEMPLATE = "bootstrap5"

# Dotted path to the ASGI entry point; extremum/routing.py defines the
# object as `application` (singular).
ASGI_APPLICATION = "extremum.routing.application"

# Channel layer backed by a Redis server on localhost, default port.
CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "channels_redis.core.RedisChannelLayer",
        "CONFIG": {
            "hosts": [("127.0.0.1", 6379)],
        },
    },
}

# Django only honours the exact name STATICFILES_FINDERS; the two
# django_plotly_dash finders serve Dash assets and component bundles.
STATICFILES_FINDERS = [
    "django.contrib.staticfiles.finders.FileSystemFinder",
    "django.contrib.staticfiles.finders.AppDirectoriesFinder",
    "django_plotly_dash.finders.DashAssetFinder",
    "django_plotly_dash.finders.DashComponentFinder",
]

# Dash component packages whose static files should be served by Django.
PLOTLY_COMPONENTS = [
    "dash_core_components",
    "dash_html_components",
    "dash_renderer",
    "dpd_components",
]
# Static files (CSS, JavaScript, Images) # Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/5.0/howto/static-files/ # https://docs.djangoproject.com/en/5.0/howto/static-files/

@ -11,7 +11,10 @@
<title>EXTREMUM</title> <title>EXTREMUM</title>
<meta name="viewport" content="'width=device-width, initial-scale=1" /> <meta name="viewport" content="'width=device-width, initial-scale=1" />
<link rel="stylesheet" href="{% static 'css/style.css' %}"> <link rel="stylesheet" href="{% static 'css/style.css' %}">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script> <script src="https://cdn.plot.ly/plotly-latest.min.js" charset="utf-8"></script>
<script src="https://code.jquery.com/jquery-3.7.1.js" integrity="sha256-eKhayi8LEQwp4NKxN+CfCh+3qOVUtJn3QNZ0TciWLP4=" crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/svg.js/3.2.0/svg.min.js" charset="utf-8"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css"> <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css">
</head> </head>
@ -29,12 +32,13 @@
<br> <br>
<br> <br>
<body onscroll="setScroll()" onload="setScreen()"> <body onscroll="setScroll()" onload="setScreen()">
{% block content %} {% block content %}
{% endblock content %} {% endblock content %}
<script src="{% static 'js/click_on_graph.js' %}"></script>
<script src="{% static 'js/hide_seek.js' %}"></script> <script src="{% static 'js/hide_seek.js' %}"></script>
<script src="{% static 'js/slider.js' %}"></script> <script src="{% static 'js/slider.js' %}"></script>
<script src="{% static 'js/keep_scroll_on_load.js' %}"></script> <script src="{% static 'js/keep_scroll_on_load.js' %}"></script>
<script src="{% static 'js/counterfactuals.js' %}"></script>
<script src="{% static 'js/click_reset.js' %}"></script>
</body> </body>
</html> </html>