Counterfactuals demonstration and computation for tabular
This commit is contained in:
parent
6047fad7ce
commit
94c682362f
base
db.sqlite3extremum
templates
Binary file not shown.
Binary file not shown.
Binary file not shown.
54
base/dash_apps/finished_apps/tsne.py
Normal file
54
base/dash_apps/finished_apps/tsne.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from base import views
|
||||||
|
from dash import dcc
|
||||||
|
from dash import html
|
||||||
|
from dash.dependencies import Input, Output
|
||||||
|
from django_plotly_dash import DjangoDash
|
||||||
|
import plotly.express as px
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
import json
|
||||||
|
|
||||||
|
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
|
||||||
|
app = DjangoDash("Tsne", external_stylesheets=external_stylesheets)
|
||||||
|
excel_file_name_preprocessed = "breast-cancer_preprocessed.csv"
|
||||||
|
preprocess_df = pd.read_csv(excel_file_name_preprocessed)
|
||||||
|
preprocess_df.drop(["id"], axis=1, inplace=True)
|
||||||
|
features = preprocess_df.loc[:, :"compactness_se"]
|
||||||
|
|
||||||
|
tsne = TSNE(n_components=2, random_state=0)
|
||||||
|
projections = tsne.fit_transform(features)
|
||||||
|
tsne = px.scatter(
|
||||||
|
projections,
|
||||||
|
x=0,
|
||||||
|
y=1,
|
||||||
|
color=preprocess_df.diagnosis.astype(str),
|
||||||
|
labels={"color": "diagnosis"},
|
||||||
|
)
|
||||||
|
tsne.update_layout(clickmode="event+select")
|
||||||
|
|
||||||
|
app.layout = html.Div(
|
||||||
|
[
|
||||||
|
dcc.Graph(id="scatter-plot", config={"displayModeBar": False}, figure=tsne),
|
||||||
|
html.Div(
|
||||||
|
className="row",
|
||||||
|
children=[
|
||||||
|
html.Div(
|
||||||
|
[
|
||||||
|
dcc.Input(
|
||||||
|
id="click-data",
|
||||||
|
type="hidden",
|
||||||
|
className="three columns",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.callback(Output("click-data", "children"), Input("scatter-plot", "clickData"))
|
||||||
|
def display_click_data(clickData, request):
|
||||||
|
# dictionary of list of dictionary ???
|
||||||
|
views.counterfactuals(clickData, request)
|
||||||
|
return json.dumps(clickData, indent=2)
|
16
base/migrations/0003_delete_document.py
Normal file
16
base/migrations/0003_delete_document.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# Generated by Django 4.2.13 on 2024-06-17 08:51
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("base", "0002_document_delete_upload"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.DeleteModel(
|
||||||
|
name="Document",
|
||||||
|
),
|
||||||
|
]
|
BIN
base/migrations/__pycache__/0003_delete_document.cpython-310.pyc
Normal file
BIN
base/migrations/__pycache__/0003_delete_document.cpython-310.pyc
Normal file
Binary file not shown.
@ -288,12 +288,13 @@ nav input {
|
|||||||
height: min-content;
|
height: min-content;
|
||||||
max-height: 40%;
|
max-height: 40%;
|
||||||
max-width: 35%;
|
max-width: 35%;
|
||||||
|
border: 1px solid #bbb;
|
||||||
|
padding: 5px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dataframe {
|
.dataframe {
|
||||||
font-size: 9pt;
|
font-size: 9pt;
|
||||||
font-family: Arial;
|
font-family: Arial;
|
||||||
border-collapse: collapse;
|
|
||||||
font-size: 0.9em;
|
font-size: 0.9em;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,6 +307,13 @@ nav input {
|
|||||||
.dataframe td {
|
.dataframe td {
|
||||||
padding: 12px 15px;
|
padding: 12px 15px;
|
||||||
text-align: left;
|
text-align: left;
|
||||||
|
border: black;
|
||||||
|
border-collapse: separate;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dataframe .clickedrow th,
|
||||||
|
.dataframe .clickedrow td {
|
||||||
|
background-color: #c6bdf8;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dataframe tbody tr {
|
.dataframe tbody tr {
|
||||||
@ -319,31 +327,36 @@ nav input {
|
|||||||
.dataframe tbody tr:last-of-type {
|
.dataframe tbody tr:last-of-type {
|
||||||
border-bottom: 2px solid #009879;
|
border-bottom: 2px solid #009879;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.dataframe tbody tr:hover {
|
||||||
|
background-color: #e8e5f9;
|
||||||
|
}
|
||||||
|
|
||||||
/* plotly toolbar */
|
/* plotly toolbar */
|
||||||
.modebar {
|
.modebar {
|
||||||
display: none !important;
|
display: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* button */
|
/* button */
|
||||||
/* CSS */
|
/* CSS */
|
||||||
.button-6 {
|
.button-6 {
|
||||||
align-items: center;
|
align-items: center;
|
||||||
background-color: #FFFFFF;
|
background-color: #ffffff;
|
||||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||||
border-radius: .25rem;
|
border-radius: 0.25rem;
|
||||||
box-shadow: rgba(0, 0, 0, 0.02) 0 1px 3px 0;
|
box-shadow: rgba(0, 0, 0, 0.02) 0 1px 3px 0;
|
||||||
box-sizing: border-box;
|
box-sizing: border-box;
|
||||||
color: rgba(0, 0, 0, 0.85);
|
color: rgba(0, 0, 0, 0.85);
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
display: inline-flex;
|
display: inline-flex;
|
||||||
font-family: system-ui,-apple-system,system-ui,"Helvetica Neue",Helvetica,Arial,sans-serif;
|
font-family: system-ui, -apple-system, system-ui, "Helvetica Neue", Helvetica,
|
||||||
|
Arial, sans-serif;
|
||||||
font-size: 16px;
|
font-size: 16px;
|
||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
line-height: 1.25;
|
line-height: 1.25;
|
||||||
margin: 0;
|
margin: 0;
|
||||||
min-height: 3rem;
|
min-height: 3rem;
|
||||||
padding: calc(.875rem - 1px) calc(1.5rem - 1px);
|
padding: calc(0.875rem - 1px) calc(1.5rem - 1px);
|
||||||
position: relative;
|
position: relative;
|
||||||
text-decoration: none;
|
text-decoration: none;
|
||||||
transition: all 250ms;
|
transition: all 250ms;
|
||||||
@ -366,7 +379,7 @@ nav input {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.button-6:active {
|
.button-6:active {
|
||||||
background-color: #F0F0F1;
|
background-color: #f0f0f1;
|
||||||
border-color: rgba(0, 0, 0, 0.15);
|
border-color: rgba(0, 0, 0, 0.15);
|
||||||
box-shadow: rgba(0, 0, 0, 0.06) 0 2px 4px;
|
box-shadow: rgba(0, 0, 0, 0.06) 0 2px 4px;
|
||||||
color: rgba(0, 0, 0, 0.65);
|
color: rgba(0, 0, 0, 0.65);
|
||||||
@ -379,7 +392,7 @@ nav input {
|
|||||||
display: flex;
|
display: flex;
|
||||||
flex-wrap: wrap;
|
flex-wrap: wrap;
|
||||||
border-radius: 0.5rem;
|
border-radius: 0.5rem;
|
||||||
background-color: #EEE;
|
background-color: #eee;
|
||||||
box-sizing: border-box;
|
box-sizing: border-box;
|
||||||
box-shadow: 0 0 0px 1px rgba(0, 0, 0, 0.06);
|
box-shadow: 0 0 0px 1px rgba(0, 0, 0, 0.06);
|
||||||
padding: 0.25rem;
|
padding: 0.25rem;
|
||||||
@ -403,12 +416,32 @@ nav input {
|
|||||||
justify-content: center;
|
justify-content: center;
|
||||||
border-radius: 0.5rem;
|
border-radius: 0.5rem;
|
||||||
border: none;
|
border: none;
|
||||||
padding: .5rem 0;
|
padding: 0.5rem 0;
|
||||||
color: rgba(51, 65, 85, 1);
|
color: rgba(51, 65, 85, 1);
|
||||||
transition: all .15s ease-in-out;
|
transition: all 0.15s ease-in-out;
|
||||||
}
|
}
|
||||||
|
|
||||||
.radio-inputs .radio input:checked + .name {
|
.radio-inputs .radio input:checked + .name {
|
||||||
background-color: #fff;
|
background-color: #fff;
|
||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.loader {
|
||||||
|
width: 48px;
|
||||||
|
height: 48px;
|
||||||
|
border: 5px solid #fff;
|
||||||
|
border-bottom-color: #ff3d00;
|
||||||
|
border-radius: 50%;
|
||||||
|
display: inline-block;
|
||||||
|
box-sizing: border-box;
|
||||||
|
animation: rotation 1s linear infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes rotation {
|
||||||
|
0% {
|
||||||
|
transform: rotate(0deg);
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: rotate(360deg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
108
base/static/js/click_on_graph.js
Normal file
108
base/static/js/click_on_graph.js
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
$(document).ready(function () {
|
||||||
|
// When click on graph (tsne)
|
||||||
|
// iterate through the elements and find the one with opacity 1
|
||||||
|
// the rest have opacity 0.2 when a button is clicked...
|
||||||
|
|
||||||
|
// Case 1: They are all opacity 1!
|
||||||
|
|
||||||
|
// In that case we cannot check all of them
|
||||||
|
// it would be inneficient,
|
||||||
|
// though if we just find one more we can determine
|
||||||
|
// that all the rest are also opacity: 1 and thus
|
||||||
|
// there was no click on a specific point
|
||||||
|
|
||||||
|
// Case 2: There was actually a click on a point!
|
||||||
|
|
||||||
|
// In that case we need to find that point. Poltly
|
||||||
|
// does not save all the points on one <g> element
|
||||||
|
// Instead on the particular example it would use 2
|
||||||
|
// <g> elements with multiple <points> in them.
|
||||||
|
// We iterate through all the collections and all the points
|
||||||
|
// until we find the point element with opacity 1.
|
||||||
|
// That way feels inefficient but in reality rarely will
|
||||||
|
// the last ever point of the last collection be selected
|
||||||
|
// and even then the complexity is just O(n). Not great
|
||||||
|
// but not bad either. It would help to be able to determine
|
||||||
|
// if a point is in collection i or j but I am not there
|
||||||
|
// yet.
|
||||||
|
|
||||||
|
document.getElementById('tsne').addEventListener('click', function () {
|
||||||
|
|
||||||
|
// remove original and counterfactual columns
|
||||||
|
if (document.getElementById("og_cf_row") && document.getElementById("og_cf_headers")) {
|
||||||
|
document.getElementById("og_cf_row").style.display = "none";
|
||||||
|
document.getElementById("og_cf_headers").style.display = "none";
|
||||||
|
}
|
||||||
|
|
||||||
|
var tsne = document.getElementById('tsne');
|
||||||
|
var points_array = tsne.getElementsByClassName("points")
|
||||||
|
var counter_break = 0;
|
||||||
|
var final_position = 0;
|
||||||
|
var j;
|
||||||
|
var i;
|
||||||
|
|
||||||
|
for (i = 0; i < points_array.length; i++) {
|
||||||
|
var points = points_array[i]
|
||||||
|
var children = points.children
|
||||||
|
if (getComputedStyle(children[0]).opacity == 1 && getComputedStyle(children[1]).opacity == 1) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
console.log(children.length)
|
||||||
|
|
||||||
|
for (j = 0; j < children.length; j++) {
|
||||||
|
var st = children[j]
|
||||||
|
opacity = getComputedStyle(st).opacity
|
||||||
|
if (opacity == 1) {
|
||||||
|
console.log("i: ", i)
|
||||||
|
console.log("J: ", j)
|
||||||
|
console.log("children.length: ", children.length)
|
||||||
|
counter_break++;
|
||||||
|
final_position += j;
|
||||||
|
console.log("final_position: ", final_position)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (counter_break) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
final_position += j;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (counter_break == 0) {
|
||||||
|
$("#cfrow").remove();
|
||||||
|
document.getElementById("cftable").style.display = "none";
|
||||||
|
document.getElementById("cfbtn").style.display = "none";
|
||||||
|
} else if (counter_break == 1) {
|
||||||
|
var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
|
||||||
|
$.ajax({
|
||||||
|
method: 'POST',
|
||||||
|
url: '',
|
||||||
|
headers: { 'X-CSRFToken': csrftoken },
|
||||||
|
data: { 'row': final_position, 'action': "click_graph" },
|
||||||
|
success: function (ret) {
|
||||||
|
|
||||||
|
row = JSON.parse(ret)
|
||||||
|
row = row["row"]
|
||||||
|
|
||||||
|
var tb = document.createElement('table');
|
||||||
|
tb.innerHTML = row.trim();
|
||||||
|
|
||||||
|
tb.setAttribute("id", "cfrow");
|
||||||
|
tb.setAttribute("class", "dataframe")
|
||||||
|
if (document.getElementById("cfrow")) {
|
||||||
|
$("#cfrow").remove();
|
||||||
|
document.getElementById("cftable").style.display = "none";
|
||||||
|
}
|
||||||
|
$("#cftable").append(tb);
|
||||||
|
document.getElementById("cftable").style.display = "block";
|
||||||
|
document.getElementById("cfbtn").style.display = "block";
|
||||||
|
},
|
||||||
|
error: function (row) {
|
||||||
|
console.log("it didnt work");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else if (counter_break > 1) {
|
||||||
|
console.log("All opacity 1")
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
30
base/static/js/click_on_row.js
Normal file
30
base/static/js/click_on_row.js
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
$(document).ready(function () {
|
||||||
|
//Highlight clicked row
|
||||||
|
$('.dataframe td').on('click', function () {
|
||||||
|
|
||||||
|
// Remove previous highlight class
|
||||||
|
$(this).closest('.dataframe').find('tr.clickedrow').removeClass('clickedrow');
|
||||||
|
// add highlight to the parent tr of the clicked td
|
||||||
|
$(this).closest('tr').addClass("clickedrow");
|
||||||
|
|
||||||
|
// fetch row
|
||||||
|
var $row = this.closest('tr')
|
||||||
|
var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
|
||||||
|
$.ajax({
|
||||||
|
method: 'POST',
|
||||||
|
url: '',
|
||||||
|
headers: { 'X-CSRFToken': csrftoken },
|
||||||
|
data: { 'row': $row.innerHTML.trim().split("\n").map(line => line.trim()).join("\n") },
|
||||||
|
success: function (id) {
|
||||||
|
var myPlot = document.getElementById('tsne');
|
||||||
|
var paths = myPlot.querySelector('.points');
|
||||||
|
var path = paths.children[id];
|
||||||
|
path.setAttribute("style", "opacity: 1; stroke-width: 0px; fill: green; fill-opacity: 1; scale: 2.0;");
|
||||||
|
console.log(path)
|
||||||
|
},
|
||||||
|
error: function (id) {
|
||||||
|
console.log("it didnt work");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
27
base/static/js/click_reset.js
Normal file
27
base/static/js/click_reset.js
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
$(document).ready(function () {
|
||||||
|
//Highlight clicked row
|
||||||
|
document.getElementById('reset_div').addEventListener('click', function () {
|
||||||
|
// on click reset the graph
|
||||||
|
// if reset button exists, hide it
|
||||||
|
document.getElementById("reset_div").style.display = "none";
|
||||||
|
document.getElementById("loader_cf").style.display = "block";
|
||||||
|
|
||||||
|
var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
|
||||||
|
$.ajax({
|
||||||
|
method: 'POST',
|
||||||
|
url: '',
|
||||||
|
headers: { 'X-CSRFToken': csrftoken },
|
||||||
|
data: {'action': "reset_graph" },
|
||||||
|
success: function (ret) {
|
||||||
|
document.getElementById("loader_cf").style.display = "none";
|
||||||
|
ret = JSON.parse(ret)
|
||||||
|
fig = ret["fig"]
|
||||||
|
document.getElementById("tsne").innerHTML = "";
|
||||||
|
$("#tsne").append(fig)
|
||||||
|
document.getElementById("reset_div").style.display = "none";
|
||||||
|
},
|
||||||
|
error: function (ret) {
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
62
base/static/js/counterfactuals.js
Normal file
62
base/static/js/counterfactuals.js
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
$(document).ready(function () {
|
||||||
|
var mySelect = document.getElementById('cfrow_id');
|
||||||
|
mySelect.onchange = (event) => {
|
||||||
|
var rowId = event.target.value;
|
||||||
|
// inputText is the row id of the counter factual
|
||||||
|
// Send ajax request to backend and inquire for the
|
||||||
|
// respective counterfactual. When it is acquired
|
||||||
|
// replace current counterfactual view with that...
|
||||||
|
var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val();
|
||||||
|
document.getElementById("loader_cf").style.display = "block";
|
||||||
|
|
||||||
|
// if reset button exists, hide it
|
||||||
|
if (document.getElementById("reset_div"))
|
||||||
|
document.getElementById("reset_div").style.display = "none";
|
||||||
|
|
||||||
|
$("#show_image").show()
|
||||||
|
$.ajax({
|
||||||
|
method: 'POST',
|
||||||
|
url: '',
|
||||||
|
headers: { 'X-CSRFToken': csrftoken },
|
||||||
|
data: { 'row': rowId, 'action': "counterfactual_select" },
|
||||||
|
success: function (ret) {
|
||||||
|
// stop loader
|
||||||
|
document.getElementById("loader_cf").style.display = "none";
|
||||||
|
|
||||||
|
// add <input type="reset" value="Reset">
|
||||||
|
if (!document.getElementById("reset")) {
|
||||||
|
var reset = document.createElement('input')
|
||||||
|
reset.setAttribute("id", "reset")
|
||||||
|
reset.setAttribute("type", "reset")
|
||||||
|
reset.setAttribute("value", "Reset")
|
||||||
|
$("#reset_div").append(reset)
|
||||||
|
}
|
||||||
|
// if reset button has been created once just display it
|
||||||
|
document.getElementById("reset_div").style.display = "block";
|
||||||
|
|
||||||
|
|
||||||
|
var tb = document.createElement('table');
|
||||||
|
ret = JSON.parse(ret)
|
||||||
|
row = ret["row"]
|
||||||
|
tb.innerHTML = row.trim();
|
||||||
|
tb.setAttribute("id", "counterfactual_selected");
|
||||||
|
tb.setAttribute("class", "dataframe")
|
||||||
|
|
||||||
|
if (document.getElementById("counterfactual_selected")) {
|
||||||
|
$("#counterfactual_selected").remove();
|
||||||
|
}
|
||||||
|
$("#counterfactual").append(tb);
|
||||||
|
document.getElementById("counterfactual_selected").style.display = "block";
|
||||||
|
|
||||||
|
fig = ret["fig"]
|
||||||
|
document.getElementById("tsne").innerHTML = "";
|
||||||
|
$("#tsne").append(fig)
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
error: function (row) {
|
||||||
|
console.log("it didnt work");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
@ -3,7 +3,7 @@ function setScreen() {
|
|||||||
window.scrollTo(0, yScreen);
|
window.scrollTo(0, yScreen);
|
||||||
}
|
}
|
||||||
function setScroll() {
|
function setScroll() {
|
||||||
var yScroll = window.pageYOffset;
|
var yScroll = window.scrollY;
|
||||||
localStorage.setItem("yPos", yScroll);
|
localStorage.setItem("yPos", yScroll);
|
||||||
}
|
}
|
||||||
function clearScreen() {
|
function clearScreen() {
|
||||||
|
@ -2,6 +2,26 @@
|
|||||||
{% block content %}
|
{% block content %}
|
||||||
{% load static %}
|
{% load static %}
|
||||||
|
|
||||||
|
<!-- <div class="container-fluid">
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-sm d-flex justify-content-center" >
|
||||||
|
<div class="dataset">
|
||||||
|
<div class="form-check">
|
||||||
|
<input class="form-check-input" type="radio" name="flexRadioDefault" id="flexRadioDefault1">
|
||||||
|
<label class="form-check-label" for="flexRadioDefault1">
|
||||||
|
Breast Cancer Dataset
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<div class="form-check">
|
||||||
|
<input class="form-check-input" type="radio" name="flexRadioDefault" id="flexRadioDefault2">
|
||||||
|
<label class="form-check-label" for="flexRadioDefault2">
|
||||||
|
Other
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div> -->
|
||||||
<div class="container-fluid">
|
<div class="container-fluid">
|
||||||
<div class="row">
|
<div class="row">
|
||||||
<div class="col-sm d-flex justify-content-center">
|
<div class="col-sm d-flex justify-content-center">
|
||||||
@ -68,8 +88,10 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<br>
|
<br>
|
||||||
<br>
|
<br>
|
||||||
|
|
||||||
<div class="container-fluid">
|
<div class="container-fluid">
|
||||||
<div class="row">
|
<div class="row">
|
||||||
<section>
|
<section>
|
||||||
@ -113,24 +135,19 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
|
||||||
<div class="row">
|
<div class="row">
|
||||||
{% if pca %}
|
{% if pca %}
|
||||||
<div class="col-sm d-flex justify-content-center">
|
<div class="col d-flex justify-content-center" id="pca">
|
||||||
{{ pca|safe }}
|
{{ pca|safe }}
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% if tsne %}
|
|
||||||
<div class="col-sm d-flex justify-content-center">
|
|
||||||
{{ tsne|safe }}
|
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<br>
|
|
||||||
<br>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
<div class="container-fluid">
|
<div class="container-fluid">
|
||||||
<form action="{% url 'home' %}" method="POST" id="traintest-form">
|
<form action="{% url 'home' %}" method="POST" id="traintest-form">
|
||||||
{% csrf_token %}
|
{% csrf_token %}
|
||||||
@ -147,7 +164,7 @@
|
|||||||
<option type="submit" value="logit" selected>Logistic Regression</option>
|
<option type="submit" value="logit" selected>Logistic Regression</option>
|
||||||
<option type="submit" value="xgb">XGBoost</option>
|
<option type="submit" value="xgb">XGBoost</option>
|
||||||
<option type="submit" value="dt">Decision Tree</option>
|
<option type="submit" value="dt">Decision Tree</option>
|
||||||
<option type="submit" value="rt">Random Forest</option>
|
<option type="submit" value="rf">Random Forest</option>
|
||||||
</select>
|
</select>
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
@ -200,6 +217,7 @@
|
|||||||
{{ fig2|safe }}
|
{{ fig2|safe }}
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% if clas_report %}
|
{% if clas_report %}
|
||||||
<div class="col-sm d-flex justify-content-center">
|
<div class="col-sm d-flex justify-content-center">
|
||||||
<div class="scrollit">
|
<div class="scrollit">
|
||||||
@ -209,4 +227,97 @@
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
<div class="container-fluid">
|
||||||
|
<div class="row">
|
||||||
|
<section>
|
||||||
|
<label style="display:flex;
|
||||||
|
flex-direction:column;
|
||||||
|
align-items: center;">
|
||||||
|
<h2>
|
||||||
|
<i class="fas fa-cog"></i> Counterfactuals
|
||||||
|
</h2>
|
||||||
|
<h5>
|
||||||
|
Pick a point on the graph
|
||||||
|
</h5>
|
||||||
|
|
||||||
|
</label>
|
||||||
|
</section>
|
||||||
|
<div class="col-sm d-flex justify-content-center">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% if tsne %}
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-md justify-content-center" id="tsne">
|
||||||
|
{{ tsne|safe }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<div class="row">
|
||||||
|
<form action="{% url 'home' %}" method="POST" id="cf-form">
|
||||||
|
{% csrf_token %}
|
||||||
|
<div class="col-md d-flex justify-content-center">
|
||||||
|
<div class="scrollit" class="cftable" name="cfrow" id="cftable" style="border: 1px solid #bbb; padding: 5px; display: none; max-width: 65%;" >
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-sm d-flex justify-content-center">
|
||||||
|
<button id="cfbtn" type="submit" class="button-6" role="button" name="cf" form="cf-form" style="display:none;">Run Counterfactuals!</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% if cfrow_og and cfrow_cf %}
|
||||||
|
<div class="row" id="og_cf_headers">
|
||||||
|
<div class="col d-flex justify-content-center">
|
||||||
|
<h2>
|
||||||
|
Original data
|
||||||
|
</h2>
|
||||||
|
</div>
|
||||||
|
<span style="display: none; width: 48px; height: 48px;" id="reset_div">
|
||||||
|
</span>
|
||||||
|
<span class="loader" id="loader_cf" style="display: none;">
|
||||||
|
</span>
|
||||||
|
<div class="col d-flex justify-content-center" style="select:invalid { color: gray; }">
|
||||||
|
<label style="display:flex;
|
||||||
|
flex-direction:column;
|
||||||
|
align-items: center;">
|
||||||
|
<select id="cfrow_id" name="cfrow_id" style="scale:1.2;">
|
||||||
|
<option value="" disabled selected hidden>Pick counterfactual...</option>
|
||||||
|
{% for row in cfdf_rows %}
|
||||||
|
<option value={{row}}>Counterfactual {{row}}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="row" id="og_cf_row">
|
||||||
|
<div class="col-md d-flex justify-content-center" >
|
||||||
|
<div class="scrollit" id="original_data">
|
||||||
|
{{ cfrow_og|safe}}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md d-flex justify-content-center" >
|
||||||
|
<div class="scrollit" id="counterfactual">
|
||||||
|
<div id="counterfactual_selected">
|
||||||
|
{{ cfrow_cf|safe }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
</div>
|
||||||
|
<div class="container-fluid">
|
||||||
|
</div>
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
{% endblock content%}
|
{% endblock content%}
|
||||||
|
@ -1,10 +1,6 @@
|
|||||||
from django.urls import path, include
|
from django.urls import path, include
|
||||||
from . import views
|
from . import views
|
||||||
from . import models
|
|
||||||
|
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('', views.home, name="home"),
|
path('', views.home, name="home"),
|
||||||
path('preprocess', views.preprocess, name="preprocess"),
|
|
||||||
path('stats', views.stats, name="stats"),
|
|
||||||
]
|
]
|
@ -1,5 +0,0 @@
|
|||||||
def stats(feature1, feature2, df):
|
|
||||||
import plotly.express as px
|
|
||||||
fig = px.scatter(df, x=feature1, y=feature2, color='Churn')
|
|
||||||
fig = fig.to_html(full_html=False)
|
|
||||||
return fig
|
|
359
base/views.py
359
base/views.py
@ -1,27 +1,26 @@
|
|||||||
from django.shortcuts import render, redirect
|
from django.shortcuts import render, HttpResponse
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from django.core.files.storage import FileSystemStorage
|
import json
|
||||||
|
import IPython
|
||||||
import pickle, os
|
import pickle, os
|
||||||
from sklearn.impute import SimpleImputer
|
from sklearn.impute import SimpleImputer
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pandas.api.types import is_string_dtype
|
|
||||||
from pandas.api.types import is_numeric_dtype
|
from pandas.api.types import is_numeric_dtype
|
||||||
from sklearn.metrics import accuracy_score, classification_report
|
from sklearn.metrics import classification_report
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
import plotly.graph_objects as go
|
|
||||||
from sklearn.preprocessing import LabelEncoder
|
from sklearn.preprocessing import LabelEncoder
|
||||||
import joblib
|
import joblib
|
||||||
from sklearn.decomposition import PCA
|
from sklearn.decomposition import PCA
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
|
import dice_ml
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
|
||||||
|
|
||||||
clas_report = None
|
|
||||||
FILE_NAME = "dataset.csv"
|
FILE_NAME = "dataset.csv"
|
||||||
PROCESS_FILE_NAME = "dataset_preprocessed.csv"
|
PROCESS_FILE_NAME = "dataset_preprocessed.csv"
|
||||||
|
|
||||||
|
|
||||||
def home(request):
|
def home(request):
|
||||||
global clas_report
|
|
||||||
# request.session.flush()
|
# request.session.flush()
|
||||||
if "fig" in request.session:
|
if "fig" in request.session:
|
||||||
fig = request.session.get("fig")
|
fig = request.session.get("fig")
|
||||||
@ -43,6 +42,26 @@ def home(request):
|
|||||||
else:
|
else:
|
||||||
tsne = None
|
tsne = None
|
||||||
|
|
||||||
|
if "cfrow_og" in request.session:
|
||||||
|
cfrow_og = request.session.get("cfrow_og")
|
||||||
|
else:
|
||||||
|
cfrow_og = None
|
||||||
|
|
||||||
|
if "cfrow_cf" in request.session:
|
||||||
|
cfrow_cf = request.session.get("cfrow_cf")
|
||||||
|
else:
|
||||||
|
cfrow_cf = None
|
||||||
|
|
||||||
|
if "cfdf_rows" in request.session:
|
||||||
|
cfdf_rows = request.session.get("cfdf_rows")
|
||||||
|
else:
|
||||||
|
cfdf_rows = None
|
||||||
|
|
||||||
|
if "clas_report" in request.session:
|
||||||
|
clas_report = request.session.get("clas_report")
|
||||||
|
else:
|
||||||
|
clas_report = None
|
||||||
|
|
||||||
if "excel_file_name" in request.session:
|
if "excel_file_name" in request.session:
|
||||||
excel_file_name = request.session.get("excel_file_name")
|
excel_file_name = request.session.get("excel_file_name")
|
||||||
else:
|
else:
|
||||||
@ -60,15 +79,91 @@ def home(request):
|
|||||||
excel_file_name_preprocessed = PROCESS_FILE_NAME
|
excel_file_name_preprocessed = PROCESS_FILE_NAME
|
||||||
request.session["excel_file_name_preprocessed"] = excel_file_name_preprocessed
|
request.session["excel_file_name_preprocessed"] = excel_file_name_preprocessed
|
||||||
|
|
||||||
|
# ajax request condition
|
||||||
|
if request.headers.get("X-Requested-With") == "XMLHttpRequest":
|
||||||
|
if request.POST.get("action") == "click_graph":
|
||||||
|
# ajax request for graph click
|
||||||
|
# given the ide of the clicked point (through ajax request)
|
||||||
|
# locate the respective dataframe row
|
||||||
|
df = pd.read_csv(excel_file_name_preprocessed)
|
||||||
|
id = request.POST["row"]
|
||||||
|
row = df.iloc[[int(id)]]
|
||||||
|
projections = request.session.get("tsne_projection")
|
||||||
|
|
||||||
|
# projections array is a list of pairs with the (x, y)
|
||||||
|
# coordinates for a point in tsne. These are actual absolute
|
||||||
|
# coordinates and not SVG.
|
||||||
|
# Now save the info for use in the future
|
||||||
|
request.session["clicked_point"] = projections[int(id)]
|
||||||
|
request.session["cfrow_id"] = request.POST["row"]
|
||||||
|
request.session["cfrow_og"] = row.to_html()
|
||||||
|
context = {"row": row.to_html()}
|
||||||
|
elif request.POST.get("action") == "counterfactual_select":
|
||||||
|
|
||||||
|
# if <select> element is used, and a specific counterfactual
|
||||||
|
# is inquired to be demonstrated:
|
||||||
|
df = pd.read_csv("counterfactuals.csv")
|
||||||
|
id = request.POST["row"]
|
||||||
|
row = df.iloc[[int(id)]]
|
||||||
|
|
||||||
|
projections = request.session.get("tsne_projection")
|
||||||
|
fig = joblib.load("tsne_cfs.sav")
|
||||||
|
|
||||||
|
# tsne_cfs is a merged scatter of 2 other scatters
|
||||||
|
# scatter[0] contains all the points of tsne
|
||||||
|
# scatter[1] contains all the counterfactual points
|
||||||
|
# + the point for which counterfactuals were
|
||||||
|
# computed...
|
||||||
|
|
||||||
|
# blur out all the points of the first scatter
|
||||||
|
fig.data[0].update(
|
||||||
|
opacity=0.3,
|
||||||
|
)
|
||||||
|
|
||||||
|
# from the second scatter (scatter[1]) locate the clicked_point
|
||||||
|
# coordinates and keep the index
|
||||||
|
l = fig.data[1]
|
||||||
|
clicked_id = -1
|
||||||
|
for clicked_id, item in enumerate(list(zip(l.x, l.y))):
|
||||||
|
if (
|
||||||
|
item[0] == request.session.get("clicked_point")[0]
|
||||||
|
and item[1] == request.session.get("clicked_point")[1]
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
|
# id is the index of the respective counter factual
|
||||||
|
# clicked_id is the index of the original point
|
||||||
|
# Blur all the points apart from these 2
|
||||||
|
fig.data[1].update(
|
||||||
|
selectedpoints=[id, clicked_id],
|
||||||
|
unselected=dict(
|
||||||
|
marker=dict(
|
||||||
|
opacity=0.3,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
context = {"row": row.to_html(), "fig": fig.to_html()}
|
||||||
|
elif request.POST.get("action") == "reset_graph":
|
||||||
|
fig = request.session.get("tsne")
|
||||||
|
context = {"fig": fig}
|
||||||
|
|
||||||
|
return HttpResponse(json.dumps(context))
|
||||||
|
|
||||||
df = pd.DataFrame()
|
df = pd.DataFrame()
|
||||||
|
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
if "csv" in request.POST:
|
if "csv" in request.POST:
|
||||||
|
|
||||||
excel_file = request.FILES["excel_file"]
|
excel_file = request.FILES["excel_file"]
|
||||||
excel_file_name = request.FILES["excel_file"].name
|
excel_file_name = request.FILES["excel_file"].name
|
||||||
|
|
||||||
fig = None
|
fig = None
|
||||||
fig2 = None
|
fig2 = None
|
||||||
|
pca = None
|
||||||
|
tsne = None
|
||||||
|
clas_report = None
|
||||||
|
cfrow_cf = None
|
||||||
|
cfrow_og = None
|
||||||
|
|
||||||
# here we dont use the name of the file since the
|
# here we dont use the name of the file since the
|
||||||
# uploaded file is not yet saved
|
# uploaded file is not yet saved
|
||||||
@ -79,11 +174,11 @@ def home(request):
|
|||||||
# fs.save(excel_file_name, excel_file)
|
# fs.save(excel_file_name, excel_file)
|
||||||
|
|
||||||
df = pd.read_csv(excel_file)
|
df = pd.read_csv(excel_file)
|
||||||
df.drop(["id"], axis=1, inplace=True)
|
# df.drop(["id"], axis=1, inplace=True)
|
||||||
df.to_csv(excel_file_name, index=False)
|
df.to_csv(excel_file_name, index=False)
|
||||||
|
|
||||||
feature1 = df.columns[0]
|
feature1 = df.columns[3]
|
||||||
feature2 = df.columns[1]
|
feature2 = df.columns[2]
|
||||||
request.session["feature1"] = feature1
|
request.session["feature1"] = feature1
|
||||||
request.session["feature2"] = feature2
|
request.session["feature2"] = feature2
|
||||||
|
|
||||||
@ -108,7 +203,8 @@ def home(request):
|
|||||||
mode = request.POST.get("colorRadio")
|
mode = request.POST.get("colorRadio")
|
||||||
model = request.POST.get("model")
|
model = request.POST.get("model")
|
||||||
test_size = float(request.POST.get("split_input"))
|
test_size = float(request.POST.get("split_input"))
|
||||||
print(test_size, mode, model)
|
|
||||||
|
request.session["model"] = model
|
||||||
if mode == "train":
|
if mode == "train":
|
||||||
if model == "logit":
|
if model == "logit":
|
||||||
con = training(excel_file_name_preprocessed, "logit", test_size)
|
con = training(excel_file_name_preprocessed, "logit", test_size)
|
||||||
@ -122,6 +218,9 @@ def home(request):
|
|||||||
elif model == "svm":
|
elif model == "svm":
|
||||||
con = training(excel_file_name_preprocessed, "svm", test_size)
|
con = training(excel_file_name_preprocessed, "svm", test_size)
|
||||||
|
|
||||||
|
elif model == "rf":
|
||||||
|
con = training(excel_file_name_preprocessed, "rf", test_size)
|
||||||
|
|
||||||
fig2 = con["fig2"]
|
fig2 = con["fig2"]
|
||||||
clas_report = con["clas_report"].to_html()
|
clas_report = con["clas_report"].to_html()
|
||||||
elif mode == "test":
|
elif mode == "test":
|
||||||
@ -137,6 +236,9 @@ def home(request):
|
|||||||
elif model == "svm":
|
elif model == "svm":
|
||||||
con = testing(excel_file_name_preprocessed, "svm")
|
con = testing(excel_file_name_preprocessed, "svm")
|
||||||
|
|
||||||
|
elif model == "rf":
|
||||||
|
con = training(excel_file_name_preprocessed, "rf", test_size)
|
||||||
|
|
||||||
fig2 = con["fig2"]
|
fig2 = con["fig2"]
|
||||||
clas_report = con["clas_report"].to_html()
|
clas_report = con["clas_report"].to_html()
|
||||||
|
|
||||||
@ -146,21 +248,16 @@ def home(request):
|
|||||||
# if file for preprocessing does not exist create it
|
# if file for preprocessing does not exist create it
|
||||||
# also apply basic preprocessing
|
# also apply basic preprocessing
|
||||||
if os.path.exists(excel_file_name_preprocessed) == False:
|
if os.path.exists(excel_file_name_preprocessed) == False:
|
||||||
|
|
||||||
# generate filename
|
# generate filename
|
||||||
idx = excel_file_name.index(".")
|
idx = excel_file_name.index(".")
|
||||||
excel_file_name_preprocessed = (
|
excel_file_name_preprocessed = (
|
||||||
excel_file_name[:idx] + "_preprocessed" + excel_file_name[idx:]
|
excel_file_name[:idx] + "_preprocessed" + excel_file_name[idx:]
|
||||||
)
|
)
|
||||||
|
|
||||||
# save file for preprocessing
|
# save file for preprocessing
|
||||||
preprocess_df = pd.read_csv(excel_file_name)
|
preprocess_df = pd.read_csv(excel_file_name)
|
||||||
fs = FileSystemStorage() # defaults to MEDIA_ROOT
|
|
||||||
request.session["excel_file_name_preprocessed"] = (
|
request.session["excel_file_name_preprocessed"] = (
|
||||||
excel_file_name_preprocessed
|
excel_file_name_preprocessed
|
||||||
)
|
)
|
||||||
preprocess_df.to_csv(excel_file_name_preprocessed, index=False)
|
|
||||||
|
|
||||||
preprocess_df.drop(
|
preprocess_df.drop(
|
||||||
["perimeter_mean", "area_mean"], axis=1, inplace=True
|
["perimeter_mean", "area_mean"], axis=1, inplace=True
|
||||||
)
|
)
|
||||||
@ -180,18 +277,21 @@ def home(request):
|
|||||||
inplace=True,
|
inplace=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# preprocess_df.drop(["id"], axis=1, inplace=True)
|
|
||||||
le = LabelEncoder()
|
le = LabelEncoder()
|
||||||
preprocess_df["diagnosis"] = le.fit_transform(
|
preprocess_df["diagnosis"] = le.fit_transform(
|
||||||
preprocess_df["diagnosis"]
|
preprocess_df["diagnosis"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
preprocess_df.to_csv(excel_file_name_preprocessed, index=False)
|
||||||
else:
|
else:
|
||||||
preprocess_df = pd.read_csv(excel_file_name_preprocessed)
|
preprocess_df = pd.read_csv(excel_file_name_preprocessed)
|
||||||
|
|
||||||
preprocess(preprocess_df, value_list, excel_file_name_preprocessed)
|
preprocess(preprocess_df, value_list, excel_file_name_preprocessed)
|
||||||
|
preprocess_df.drop(["id"], axis=1, inplace=True)
|
||||||
|
|
||||||
|
# PCA
|
||||||
pca = PCA()
|
pca = PCA()
|
||||||
pca.fit(preprocess_df)
|
pca.fit(preprocess_df.loc[:, "radius_mean":])
|
||||||
exp_var_cumul = np.cumsum(pca.explained_variance_ratio_)
|
exp_var_cumul = np.cumsum(pca.explained_variance_ratio_)
|
||||||
pca = px.area(
|
pca = px.area(
|
||||||
x=range(1, exp_var_cumul.shape[0] + 1),
|
x=range(1, exp_var_cumul.shape[0] + 1),
|
||||||
@ -199,31 +299,112 @@ def home(request):
|
|||||||
labels={"x": "# Components", "y": "Explained Variance"},
|
labels={"x": "# Components", "y": "Explained Variance"},
|
||||||
).to_html()
|
).to_html()
|
||||||
|
|
||||||
features = preprocess_df.loc[:, :"compactness_se"]
|
# tSNE
|
||||||
|
|
||||||
tsne = TSNE(n_components=2, random_state=0)
|
tsne = TSNE(n_components=2, random_state=0)
|
||||||
projections = tsne.fit_transform(features)
|
projections = tsne.fit_transform(
|
||||||
|
preprocess_df.drop(["diagnosis"], axis=1).values
|
||||||
|
)
|
||||||
|
tsne_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"0": projections[:, 0],
|
||||||
|
"1": projections[:, 1],
|
||||||
|
"diagnosis": preprocess_df["diagnosis"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
tsne = px.scatter(
|
tsne = px.scatter(
|
||||||
projections,
|
tsne_df,
|
||||||
x=0,
|
x="0",
|
||||||
y=1,
|
y="1",
|
||||||
color=preprocess_df.diagnosis,
|
color="diagnosis",
|
||||||
labels={"color": "diagnosis"},
|
color_continuous_scale=px.colors.sequential.Rainbow,
|
||||||
).to_html()
|
)
|
||||||
|
|
||||||
|
tsne.update_layout(clickmode="event+select")
|
||||||
|
# tsne_opacity.update_layout(clickmode="event+select")
|
||||||
|
pickle.dump(tsne, open("tsne.sav", "wb"))
|
||||||
|
tsne = tsne.to_html()
|
||||||
|
request.session["tsne_projection"] = projections.tolist()
|
||||||
|
elif "cf" in request.POST:
|
||||||
|
|
||||||
|
excel_file_name_preprocessed = request.session.get(
|
||||||
|
"excel_file_name_preprocessed"
|
||||||
|
)
|
||||||
|
|
||||||
|
df = pd.read_csv(excel_file_name_preprocessed)
|
||||||
|
df_id = request.session.get("cfrow_id")
|
||||||
|
model = request.session.get("model")
|
||||||
|
row = df.iloc[[int(df_id)]]
|
||||||
|
counterfactuals(row, model, excel_file_name_preprocessed)
|
||||||
|
|
||||||
|
# get coordinates of the clicked point (saved during the actual click)
|
||||||
|
clicked_point = request.session.get("clicked_point")
|
||||||
|
clicked_point_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"0": clicked_point[0],
|
||||||
|
"1": clicked_point[1],
|
||||||
|
"diagnosis": row.diagnosis,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
clicked_point_df.reset_index(drop=True)
|
||||||
|
|
||||||
|
# tSNE
|
||||||
|
cf_df = pd.read_csv("counterfactuals.csv")
|
||||||
|
|
||||||
|
# get rows count
|
||||||
|
request.session["cfdf_rows"] = cf_df.index.values.tolist()
|
||||||
|
|
||||||
|
# select a cf randomly for demonstration
|
||||||
|
request.session["cfrow_cf"] = cf_df.iloc[:1].to_html()
|
||||||
|
|
||||||
|
df_merged = pd.concat(
|
||||||
|
[cf_df, df.drop("id", axis=1)], ignore_index=True, axis=0
|
||||||
|
)
|
||||||
|
tsne_cf = TSNE(n_components=2, random_state=0)
|
||||||
|
projections = tsne_cf.fit_transform(
|
||||||
|
df_merged.drop(["diagnosis"], axis=1).values
|
||||||
|
)
|
||||||
|
|
||||||
|
cf_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"0": projections[:3, 0],
|
||||||
|
"1": projections[:3, 1],
|
||||||
|
"diagnosis": cf_df.diagnosis.iloc[:3],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cf_df = pd.concat([cf_df, clicked_point_df], ignore_index=True, axis=0)
|
||||||
|
|
||||||
|
tsne = joblib.load("tsne.sav")
|
||||||
|
cf_s = px.scatter(
|
||||||
|
cf_df,
|
||||||
|
x="0",
|
||||||
|
y="1",
|
||||||
|
color="diagnosis",
|
||||||
|
color_continuous_scale=px.colors.sequential.Rainbow,
|
||||||
|
)
|
||||||
|
|
||||||
|
cf_s.update_traces(
|
||||||
|
marker=dict(
|
||||||
|
size=10,
|
||||||
|
symbol="circle",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
tsne.add_trace(cf_s.data[0])
|
||||||
|
pickle.dump(tsne, open("tsne_cfs.sav", "wb"))
|
||||||
|
tsne = tsne.to_html()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
if os.path.exists(excel_file_name) == False:
|
if os.path.exists(excel_file_name) == False:
|
||||||
excel_file_name = "dataset.csv"
|
excel_file_name = "dataset.csv"
|
||||||
request.session["excel_file_name"] = excel_file_name
|
request.session["excel_file_name"] = excel_file_name
|
||||||
|
|
||||||
df = pd.read_csv(excel_file_name)
|
df = pd.read_csv(excel_file_name)
|
||||||
fig2 = None
|
|
||||||
|
|
||||||
# just random columns to plot
|
# just random columns to plot
|
||||||
feature1 = df.columns[0]
|
feature1 = df.columns[3]
|
||||||
feature2 = df.columns[1]
|
feature2 = df.columns[2]
|
||||||
request.session["feature1"] = feature1
|
request.session["feature1"] = feature1
|
||||||
request.session["feature2"] = feature2
|
request.session["feature2"] = feature2
|
||||||
fig = stats(
|
fig = stats(
|
||||||
@ -237,10 +418,11 @@ def home(request):
|
|||||||
request.session["fig2"] = fig2
|
request.session["fig2"] = fig2
|
||||||
request.session["pca"] = pca
|
request.session["pca"] = pca
|
||||||
request.session["tsne"] = tsne
|
request.session["tsne"] = tsne
|
||||||
|
request.session["clas_report"] = clas_report
|
||||||
|
|
||||||
data_to_display = df[:5].to_html()
|
data_to_display = df[:10].to_html(index=False)
|
||||||
request.session["data_to_display"] = data_to_display
|
request.session["data_to_display"] = data_to_display
|
||||||
labels = df.columns
|
labels = df.columns[2:]
|
||||||
context = {
|
context = {
|
||||||
"data_to_display": data_to_display,
|
"data_to_display": data_to_display,
|
||||||
"excel_file": excel_file_name,
|
"excel_file": excel_file_name,
|
||||||
@ -252,6 +434,9 @@ def home(request):
|
|||||||
"clas_report": clas_report,
|
"clas_report": clas_report,
|
||||||
"pca": pca,
|
"pca": pca,
|
||||||
"tsne": tsne,
|
"tsne": tsne,
|
||||||
|
"cfrow_og": cfrow_og,
|
||||||
|
"cfrow_cf": cfrow_cf,
|
||||||
|
"cfdf_rows": cfdf_rows,
|
||||||
}
|
}
|
||||||
|
|
||||||
return render(request, "base/home.html", context)
|
return render(request, "base/home.html", context)
|
||||||
@ -278,6 +463,9 @@ def stats(name, feature1, feature2):
|
|||||||
else:
|
else:
|
||||||
# they both are categorical so do scatter
|
# they both are categorical so do scatter
|
||||||
fig = px.histogram(df, x=feature1, color=feature2)
|
fig = px.histogram(df, x=feature1, color=feature2)
|
||||||
|
|
||||||
|
fig.update_layout(clickmode="event+select")
|
||||||
|
|
||||||
fig = fig.to_html(full_html=False)
|
fig = fig.to_html(full_html=False)
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
@ -285,13 +473,14 @@ def stats(name, feature1, feature2):
|
|||||||
def preprocess(data, value_list, name):
|
def preprocess(data, value_list, name):
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
ids = data["id"]
|
||||||
|
data = data.drop(["id"], axis=1)
|
||||||
for type in value_list:
|
for type in value_list:
|
||||||
if type == "std":
|
if type == "std":
|
||||||
# define standard scaler
|
# define standard scaler
|
||||||
scaler = StandardScaler()
|
scaler = StandardScaler()
|
||||||
y = data["diagnosis"]
|
y = data["diagnosis"]
|
||||||
|
|
||||||
if is_numeric_dtype(data["diagnosis"]):
|
|
||||||
# if class column is numeric do not
|
# if class column is numeric do not
|
||||||
# apply preprocessing
|
# apply preprocessing
|
||||||
data = data.drop(["diagnosis"], axis=1)
|
data = data.drop(["diagnosis"], axis=1)
|
||||||
@ -299,13 +488,15 @@ def preprocess(data, value_list, name):
|
|||||||
# transform data
|
# transform data
|
||||||
cols = data.select_dtypes(np.number).columns
|
cols = data.select_dtypes(np.number).columns
|
||||||
data[cols] = scaler.fit_transform(data[cols])
|
data[cols] = scaler.fit_transform(data[cols])
|
||||||
y = y.to_frame()
|
data = pd.concat([y.to_frame(), data], axis=1, ignore_index=False)
|
||||||
data = data.join(y)
|
|
||||||
|
|
||||||
if type == "onehot":
|
if type == "onehot":
|
||||||
data = pd.get_dummies(data)
|
data = pd.get_dummies(data)
|
||||||
|
|
||||||
if type == "imp":
|
if type == "imp":
|
||||||
|
y = data["diagnosis"]
|
||||||
|
data = data.drop(["diagnosis"], axis=1)
|
||||||
|
|
||||||
data_numeric = data.select_dtypes(exclude=["object"])
|
data_numeric = data.select_dtypes(exclude=["object"])
|
||||||
data_categorical = data.select_dtypes(exclude=["number"])
|
data_categorical = data.select_dtypes(exclude=["number"])
|
||||||
imp = SimpleImputer(missing_values=np.nan, strategy="most_frequent")
|
imp = SimpleImputer(missing_values=np.nan, strategy="most_frequent")
|
||||||
@ -318,17 +509,20 @@ def preprocess(data, value_list, name):
|
|||||||
[data_numeric, data_categorical], axis=1, ignore_index=False
|
[data_numeric, data_categorical], axis=1, ignore_index=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
data = pd.concat([y.to_frame(), data], axis=1, ignore_index=False)
|
||||||
|
new = pd.concat([ids.to_frame(), data], axis=1, ignore_index=False)
|
||||||
os.remove(name)
|
os.remove(name)
|
||||||
data.to_csv(name, index=False)
|
new.to_csv(name, index=False)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def training(name, type, test_size=0.7):
|
def training(name, type, test_size=0.7):
|
||||||
data = pd.read_csv(name)
|
data = pd.read_csv(name)
|
||||||
|
X = data.drop("id", axis=1)
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
y = data["diagnosis"]
|
y = X["diagnosis"]
|
||||||
X = data.drop("diagnosis", axis=1)
|
X = X.drop("diagnosis", axis=1)
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
X, y, shuffle=True, test_size=test_size, stratify=y, random_state=42
|
X, y, shuffle=True, test_size=test_size, stratify=y, random_state=42
|
||||||
)
|
)
|
||||||
@ -377,6 +571,16 @@ def training(name, type, test_size=0.7):
|
|||||||
importance = svc.coef_[0]
|
importance = svc.coef_[0]
|
||||||
model = svc
|
model = svc
|
||||||
|
|
||||||
|
if "rf" == type:
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
|
||||||
|
rf = RandomForestClassifier()
|
||||||
|
rf.fit(X_train, y_train)
|
||||||
|
y_pred = rf.predict(X_test)
|
||||||
|
filename = "rf.sav"
|
||||||
|
importance = rf.feature_importances_
|
||||||
|
model = rf
|
||||||
|
|
||||||
clas_report = classification_report(y_test, y_pred, output_dict=True)
|
clas_report = classification_report(y_test, y_pred, output_dict=True)
|
||||||
clas_report = pd.DataFrame(clas_report).transpose()
|
clas_report = pd.DataFrame(clas_report).transpose()
|
||||||
clas_report = clas_report.sort_values(by=["f1-score"], ascending=False)
|
clas_report = clas_report.sort_values(by=["f1-score"], ascending=False)
|
||||||
@ -423,6 +627,13 @@ def testing(name, type):
|
|||||||
importance = svc.coef_[0]
|
importance = svc.coef_[0]
|
||||||
model = svc
|
model = svc
|
||||||
|
|
||||||
|
if "rf" == type:
|
||||||
|
filename = "rf.sav"
|
||||||
|
rf = joblib.load(filename)
|
||||||
|
y_pred = rf.predict(X_test)
|
||||||
|
importance = rf.feature_importances_
|
||||||
|
model = rf
|
||||||
|
|
||||||
clas_report = classification_report(y_test, y_pred, output_dict=True)
|
clas_report = classification_report(y_test, y_pred, output_dict=True)
|
||||||
clas_report = pd.DataFrame(clas_report).transpose()
|
clas_report = pd.DataFrame(clas_report).transpose()
|
||||||
clas_report = clas_report.sort_values(by=["f1-score"], ascending=False)
|
clas_report = clas_report.sort_values(by=["f1-score"], ascending=False)
|
||||||
@ -433,3 +644,73 @@ def testing(name, type):
|
|||||||
"clas_report": clas_report,
|
"clas_report": clas_report,
|
||||||
}
|
}
|
||||||
return con
|
return con
|
||||||
|
|
||||||
|
|
||||||
|
def counterfactuals(row, model, preprocessed_file):
|
||||||
|
df = pd.read_csv(preprocessed_file)
|
||||||
|
df = df.drop("id", axis=1)
|
||||||
|
row = row.drop(["diagnosis", "id"], axis=1)
|
||||||
|
if "logit" == model:
|
||||||
|
filename = "lg.sav"
|
||||||
|
clf = joblib.load(filename)
|
||||||
|
model = clf
|
||||||
|
|
||||||
|
if "xgb" == model:
|
||||||
|
filename = "xgb.sav"
|
||||||
|
xgb = joblib.load(filename)
|
||||||
|
model = xgb
|
||||||
|
|
||||||
|
if "dt" == model:
|
||||||
|
filename = "dt.sav"
|
||||||
|
dt = joblib.load(filename)
|
||||||
|
model = dt
|
||||||
|
|
||||||
|
if "svm" == model:
|
||||||
|
filename = "svc.sav"
|
||||||
|
svc = joblib.load(filename)
|
||||||
|
model = svc
|
||||||
|
|
||||||
|
if "rf" == model:
|
||||||
|
filename = "rf.sav"
|
||||||
|
rf = joblib.load(filename)
|
||||||
|
model = rf
|
||||||
|
|
||||||
|
d = dice_ml.Data(
|
||||||
|
dataframe=df,
|
||||||
|
continuous_features=[
|
||||||
|
"radius_mean",
|
||||||
|
"texture_mean",
|
||||||
|
"smoothness_mean",
|
||||||
|
"compactness_mean",
|
||||||
|
"concavity_mean",
|
||||||
|
"symmetry_mean",
|
||||||
|
"fractal_dimension_mean",
|
||||||
|
"radius_se",
|
||||||
|
"texture_se",
|
||||||
|
"smoothness_se",
|
||||||
|
"compactness_se",
|
||||||
|
"concavity_se",
|
||||||
|
"concave points_se",
|
||||||
|
"symmetry_se",
|
||||||
|
"fractal_dimension_se",
|
||||||
|
"compactness_worst",
|
||||||
|
"concavity_worst",
|
||||||
|
"concave points_worst",
|
||||||
|
"fractal_dimension_worst",
|
||||||
|
],
|
||||||
|
outcome_name="diagnosis",
|
||||||
|
)
|
||||||
|
|
||||||
|
m = dice_ml.Model(model=model, backend="sklearn")
|
||||||
|
exp = dice_ml.Dice(d, m)
|
||||||
|
dice_exp = exp.generate_counterfactuals(
|
||||||
|
row, # The data from the 1st row of our dataframe
|
||||||
|
total_CFs=3, # Total number of Counterfactual Examples we want to print out. There can be multiple.
|
||||||
|
desired_class="opposite", # We want to convert the quality to the opposite one.
|
||||||
|
)
|
||||||
|
|
||||||
|
# save cfs
|
||||||
|
dice_exp.cf_examples_list[0].final_cfs_df.to_csv(
|
||||||
|
path_or_buf="counterfactuals.csv", index=False
|
||||||
|
)
|
||||||
|
return dice_exp._cf_examples_list
|
||||||
|
BIN
db.sqlite3
BIN
db.sqlite3
Binary file not shown.
Binary file not shown.
3
extremum/routing.py
Normal file
3
extremum/routing.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from channels.routing import ProtocolTypeRouter
|
||||||
|
|
||||||
|
application = ProtocolTypeRouter({})
|
@ -40,6 +40,9 @@ INSTALLED_APPS = [
|
|||||||
"django.contrib.staticfiles",
|
"django.contrib.staticfiles",
|
||||||
"base.apps.BaseConfig",
|
"base.apps.BaseConfig",
|
||||||
"bootstrap5",
|
"bootstrap5",
|
||||||
|
"django_plotly_dash.apps.DjangoPlotlyDashConfig",
|
||||||
|
"channels",
|
||||||
|
"channels_redis",
|
||||||
]
|
]
|
||||||
|
|
||||||
MIDDLEWARE = [
|
MIDDLEWARE = [
|
||||||
@ -105,6 +108,7 @@ AUTH_PASSWORD_VALIDATORS = [
|
|||||||
|
|
||||||
# Internationalization
|
# Internationalization
|
||||||
# https://docs.djangoproject.com/en/5.0/topics/i18n/
|
# https://docs.djangoproject.com/en/5.0/topics/i18n/
|
||||||
|
X_FRAME_OPTIONS = 'SAMEORIGIN'
|
||||||
|
|
||||||
LANGUAGE_CODE = "en-us"
|
LANGUAGE_CODE = "en-us"
|
||||||
|
|
||||||
@ -114,6 +118,32 @@ USE_I18N = True
|
|||||||
|
|
||||||
USE_TZ = True
|
USE_TZ = True
|
||||||
|
|
||||||
|
CRISPY_TEMPLATE = "bootsrap5"
|
||||||
|
|
||||||
|
ASGI_APPLICATION = "extremum.routing.applications"
|
||||||
|
|
||||||
|
CHANNEL_LAYERS = {
|
||||||
|
'default': {
|
||||||
|
'BACKEND': 'channels_redis.core.RedisChannelLayer',
|
||||||
|
'CONFIG': {
|
||||||
|
'hosts': [('127.0.0.1', 6379),],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
SSTATICFILRES_FINDERS = [
|
||||||
|
"django.contrib.staticfiles.finders.FileSystemFinder",
|
||||||
|
"django.contrib.staticfiles.finders.AppDirectoriesFinder",
|
||||||
|
"django_plotly_dash.finder.DashAssetFinder",
|
||||||
|
"dajango_plotly_dash.finders.DashComponentsFinder",
|
||||||
|
]
|
||||||
|
|
||||||
|
PLOTLY_COMPONENTS = [
|
||||||
|
"dash_core_components",
|
||||||
|
"dash_html_components",
|
||||||
|
"dash_renderer",
|
||||||
|
"dpd_components",
|
||||||
|
]
|
||||||
|
|
||||||
# Static files (CSS, JavaScript, Images)
|
# Static files (CSS, JavaScript, Images)
|
||||||
# https://docs.djangoproject.com/en/5.0/howto/static-files/
|
# https://docs.djangoproject.com/en/5.0/howto/static-files/
|
||||||
|
@ -11,7 +11,10 @@
|
|||||||
<title>EXTREMUM</title>
|
<title>EXTREMUM</title>
|
||||||
<meta name="viewport" content="'width=device-width, initial-scale=1" />
|
<meta name="viewport" content="'width=device-width, initial-scale=1" />
|
||||||
<link rel="stylesheet" href="{% static 'css/style.css' %}">
|
<link rel="stylesheet" href="{% static 'css/style.css' %}">
|
||||||
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
|
<script src="https://cdn.plot.ly/plotly-latest.min.js" charset="utf-8"></script>
|
||||||
|
<script src="https://code.jquery.com/jquery-3.7.1.js" integrity="sha256-eKhayi8LEQwp4NKxN+CfCh+3qOVUtJn3QNZ0TciWLP4=" crossorigin="anonymous"></script>
|
||||||
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/svg.js/3.2.0/svg.min.js" charset="utf-8"></script>
|
||||||
|
|
||||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css">
|
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css">
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
@ -29,12 +32,13 @@
|
|||||||
<br>
|
<br>
|
||||||
<br>
|
<br>
|
||||||
<body onscroll="setScroll()" onload="setScreen()">
|
<body onscroll="setScroll()" onload="setScreen()">
|
||||||
|
|
||||||
{% block content %}
|
{% block content %}
|
||||||
{% endblock content %}
|
{% endblock content %}
|
||||||
|
<script src="{% static 'js/click_on_graph.js' %}"></script>
|
||||||
<script src="{% static 'js/hide_seek.js' %}"></script>
|
<script src="{% static 'js/hide_seek.js' %}"></script>
|
||||||
<script src="{% static 'js/slider.js' %}"></script>
|
<script src="{% static 'js/slider.js' %}"></script>
|
||||||
<script src="{% static 'js/keep_scroll_on_load.js' %}"></script>
|
<script src="{% static 'js/keep_scroll_on_load.js' %}"></script>
|
||||||
|
<script src="{% static 'js/counterfactuals.js' %}"></script>
|
||||||
|
<script src="{% static 'js/click_reset.js' %}"></script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user