diff --git a/base/__pycache__/urls.cpython-310.pyc b/base/__pycache__/urls.cpython-310.pyc index b71d11b40..988c25f98 100644 Binary files a/base/__pycache__/urls.cpython-310.pyc and b/base/__pycache__/urls.cpython-310.pyc differ diff --git a/base/__pycache__/views.cpython-310.pyc b/base/__pycache__/views.cpython-310.pyc index e63eb99ed..643c89d17 100644 Binary files a/base/__pycache__/views.cpython-310.pyc and b/base/__pycache__/views.cpython-310.pyc differ diff --git a/base/dash_apps/finished_apps/__pycache__/simpleexample.cpython-310.pyc b/base/dash_apps/finished_apps/__pycache__/simpleexample.cpython-310.pyc new file mode 100644 index 000000000..ddbcde14d Binary files /dev/null and b/base/dash_apps/finished_apps/__pycache__/simpleexample.cpython-310.pyc differ diff --git a/base/dash_apps/finished_apps/tsne.py b/base/dash_apps/finished_apps/tsne.py new file mode 100644 index 000000000..2140ab933 --- /dev/null +++ b/base/dash_apps/finished_apps/tsne.py @@ -0,0 +1,54 @@ +from base import views +from dash import dcc +from dash import html +from dash.dependencies import Input, Output +from django_plotly_dash import DjangoDash +import plotly.express as px +import pandas as pd +from sklearn.manifold import TSNE +import json + +external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"] +app = DjangoDash("Tsne", external_stylesheets=external_stylesheets) +excel_file_name_preprocessed = "breast-cancer_preprocessed.csv" +preprocess_df = pd.read_csv(excel_file_name_preprocessed) +preprocess_df.drop(["id"], axis=1, inplace=True) +features = preprocess_df.loc[:, :"compactness_se"] + +tsne = TSNE(n_components=2, random_state=0) +projections = tsne.fit_transform(features) +tsne = px.scatter( + projections, + x=0, + y=1, + color=preprocess_df.diagnosis.astype(str), + labels={"color": "diagnosis"}, +) +tsne.update_layout(clickmode="event+select") + +app.layout = html.Div( + [ + dcc.Graph(id="scatter-plot", config={"displayModeBar": False}, figure=tsne), + html.Div( + className="row", + children=[ + html.Div( + [ + dcc.Input( + id="click-data", + type="hidden", + className="three columns", + ) + ] + ) + ], + ), + ] +) + + +@app.callback(Output("click-data", "children"), Input("scatter-plot", "clickData")) +def display_click_data(clickData, request): + # dictionary of list of dictionary ??? + views.counterfactuals(clickData, request) + return json.dumps(clickData, indent=2) diff --git a/base/migrations/0003_delete_document.py b/base/migrations/0003_delete_document.py new file mode 100644 index 000000000..7f2384742 --- /dev/null +++ b/base/migrations/0003_delete_document.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.13 on 2024-06-17 08:51 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("base", "0002_document_delete_upload"), + ] + + operations = [ + migrations.DeleteModel( + name="Document", + ), + ] diff --git a/base/migrations/__pycache__/0003_delete_document.cpython-310.pyc b/base/migrations/__pycache__/0003_delete_document.cpython-310.pyc new file mode 100644 index 000000000..611e933f1 Binary files /dev/null and b/base/migrations/__pycache__/0003_delete_document.cpython-310.pyc differ diff --git a/base/static/css/style.css b/base/static/css/style.css index 6b04bce02..503bb7e07 100644 --- a/base/static/css/style.css +++ b/base/static/css/style.css @@ -288,12 +288,13 @@ nav input { height: min-content; max-height: 40%; max-width: 35%; + border: 1px solid #bbb; + padding: 5px; } .dataframe { font-size: 9pt; font-family: Arial; - border-collapse: collapse; font-size: 0.9em; } @@ -306,6 +307,13 @@ nav input { .dataframe td { padding: 12px 15px; text-align: left; + border: black; + border-collapse: separate; +} + +.dataframe .clickedrow th, +.dataframe .clickedrow td { + background-color: #c6bdf8; } .dataframe tbody tr { @@ -319,31 +327,36 @@ nav input { .dataframe tbody tr:last-of-type { border-bottom: 2px solid #009879; } + +.dataframe tbody tr:hover { + background-color: #e8e5f9; +} + /* plotly toolbar */ .modebar { display: none !important; } - /* button */ /* CSS */ .button-6 { align-items: center; - background-color: #FFFFFF; + background-color: #ffffff; border: 1px solid rgba(0, 0, 0, 0.1); - border-radius: .25rem; + border-radius: 0.25rem; box-shadow: rgba(0, 0, 0, 0.02) 0 1px 3px 0; box-sizing: border-box; color: rgba(0, 0, 0, 0.85); cursor: pointer; display: inline-flex; - font-family: system-ui,-apple-system,system-ui,"Helvetica Neue",Helvetica,Arial,sans-serif; + font-family: system-ui, -apple-system, system-ui, "Helvetica Neue", Helvetica, + Arial, sans-serif; font-size: 16px; font-weight: 600; justify-content: center; line-height: 1.25; margin: 0; min-height: 3rem; - padding: calc(.875rem - 1px) calc(1.5rem - 1px); + padding: calc(0.875rem - 1px) calc(1.5rem - 1px); position: relative; text-decoration: none; transition: all 250ms; @@ -366,7 +379,7 @@ nav input { } .button-6:active { - background-color: #F0F0F1; + background-color: #f0f0f1; border-color: rgba(0, 0, 0, 0.15); box-shadow: rgba(0, 0, 0, 0.06) 0 2px 4px; color: rgba(0, 0, 0, 0.65); @@ -379,7 +392,7 @@ nav input { display: flex; flex-wrap: wrap; border-radius: 0.5rem; - background-color: #EEE; + background-color: #eee; box-sizing: border-box; box-shadow: 0 0 0px 1px rgba(0, 0, 0, 0.06); padding: 0.25rem; @@ -403,12 +416,32 @@ nav input { justify-content: center; border-radius: 0.5rem; border: none; - padding: .5rem 0; + padding: 0.5rem 0; color: rgba(51, 65, 85, 1); - transition: all .15s ease-in-out; + transition: all 0.15s ease-in-out; } .radio-inputs .radio input:checked + .name { background-color: #fff; font-weight: 600; -} \ No newline at end of file +} + +.loader { + width: 48px; + height: 48px; + border: 5px solid #fff; + border-bottom-color: #ff3d00; + border-radius: 50%; + display: inline-block; + box-sizing: border-box; + animation: rotation 1s linear infinite; +} + +@keyframes rotation { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} diff --git a/base/static/js/click_on_graph.js b/base/static/js/click_on_graph.js new file mode 100644 index 000000000..fa3de051c --- /dev/null +++ b/base/static/js/click_on_graph.js @@ -0,0 +1,108 @@ +$(document).ready(function () { + // When click on graph (tsne) + // iterate through the elements and find the one with opacity 1 + // the rest have opacity 0.2 when a button is clicked... + + // Case 1: They are all opacity 1! + + // In that case we cannot check all of them + // it would be inneficient, + // though if we just find one more we can determine + // that all the rest are also opacity: 1 and thus + // there was no click on a specific point + + // Case 2: There was actually a click on a point! + + // In that case we need to find that point. Poltly + // does not save all the points on one <g> element + // Instead on the particular example it would use 2 + // <g> elements with multiple <points> in them. + // We iterate through all the collections and all the points + // until we find the point element with opacity 1. + // That way feels inefficient but in reality rarely will + // the last ever point of the last collection be selected + // and even then the complexity is just O(n). Not great + // but not bad either. It would help to be able to determine + // if a point is in collection i or j but I am not there + // yet. + + document.getElementById('tsne').addEventListener('click', function () { + + // remove original and counterfactual columns + if (document.getElementById("og_cf_row") && document.getElementById("og_cf_headers")) { + document.getElementById("og_cf_row").style.display = "none"; + document.getElementById("og_cf_headers").style.display = "none"; + } + + var tsne = document.getElementById('tsne'); + var points_array = tsne.getElementsByClassName("points") + var counter_break = 0; + var final_position = 0; + var j; + var i; + + for (i = 0; i < points_array.length; i++) { + var points = points_array[i] + var children = points.children + if (getComputedStyle(children[0]).opacity == 1 && getComputedStyle(children[1]).opacity == 1) { + break; + } + console.log(children.length) + + for (j = 0; j < children.length; j++) { + var st = children[j] + opacity = getComputedStyle(st).opacity + if (opacity == 1) { + console.log("i: ", i) + console.log("J: ", j) + console.log("children.length: ", children.length) + counter_break++; + final_position += j; + console.log("final_position: ", final_position) + break; + } + } + if (counter_break) { + break; + } + final_position += j; + } + + if (counter_break == 0) { + $("#cfrow").remove(); + document.getElementById("cftable").style.display = "none"; + document.getElementById("cfbtn").style.display = "none"; + } else if (counter_break == 1) { + var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val(); + $.ajax({ + method: 'POST', + url: '', + headers: { 'X-CSRFToken': csrftoken }, + data: { 'row': final_position, 'action': "click_graph" }, + success: function (ret) { + + row = JSON.parse(ret) + row = row["row"] + + var tb = document.createElement('table'); + tb.innerHTML = row.trim(); + + tb.setAttribute("id", "cfrow"); + tb.setAttribute("class", "dataframe") + if (document.getElementById("cfrow")) { + $("#cfrow").remove(); + document.getElementById("cftable").style.display = "none"; + } + $("#cftable").append(tb); + document.getElementById("cftable").style.display = "block"; + document.getElementById("cfbtn").style.display = "block"; + }, + error: function (row) { + console.log("it didnt work"); + } + }); + } else if (counter_break > 1) { + console.log("All opacity 1") + } + }); +}); diff --git a/base/static/js/click_on_row.js b/base/static/js/click_on_row.js new file mode 100644 index 000000000..648aa2125 --- /dev/null +++ b/base/static/js/click_on_row.js @@ -0,0 +1,30 @@ +$(document).ready(function () { + //Highlight clicked row + $('.dataframe td').on('click', function () { + + // Remove previous highlight class + $(this).closest('.dataframe').find('tr.clickedrow').removeClass('clickedrow'); + // add highlight to the parent tr of the clicked td + $(this).closest('tr').addClass("clickedrow"); + + // fetch row + var $row = this.closest('tr') + var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val(); + $.ajax({ + method: 'POST', + url: '', + headers: { 'X-CSRFToken': csrftoken }, + data: { 'row': $row.innerHTML.trim().split("\n").map(line => line.trim()).join("\n") }, + success: function (id) { + var myPlot = document.getElementById('tsne'); + var paths = myPlot.querySelector('.points'); + var path = paths.children[id]; + path.setAttribute("style", "opacity: 1; stroke-width: 0px; fill: green; fill-opacity: 1; scale: 2.0;"); + console.log(path) + }, + error: function (id) { + console.log("it didnt work"); + } + }); + }); +}); diff --git a/base/static/js/click_reset.js b/base/static/js/click_reset.js new file mode 100644 index 000000000..e5c1ccc0b --- /dev/null +++ b/base/static/js/click_reset.js @@ -0,0 +1,27 @@ +$(document).ready(function () { + //Highlight clicked row + document.getElementById('reset_div').addEventListener('click', function () { + // on click reset the graph + // if reset button exists, hide it + document.getElementById("reset_div").style.display = "none"; + document.getElementById("loader_cf").style.display = "block"; + + var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val(); + $.ajax({ + method: 'POST', + url: '', + headers: { 'X-CSRFToken': csrftoken }, + data: {'action': "reset_graph" }, + success: function (ret) { + document.getElementById("loader_cf").style.display = "none"; + ret = JSON.parse(ret) + fig = ret["fig"] + document.getElementById("tsne").innerHTML = ""; + $("#tsne").append(fig) + document.getElementById("reset_div").style.display = "none"; + }, + error: function (ret) { + } + }); + }); +}); diff --git a/base/static/js/counterfactuals.js b/base/static/js/counterfactuals.js new file mode 100644 index 000000000..06612fd5e --- /dev/null +++ b/base/static/js/counterfactuals.js @@ -0,0 +1,62 @@ +$(document).ready(function () { + var mySelect = document.getElementById('cfrow_id'); + mySelect.onchange = (event) => { + var rowId = event.target.value; + // inputText is the row id of the counter factual + // Send ajax request to backend and inquire for the + // respective counterfactual. When it is acquired + // replace current counterfactual view with that... + var csrftoken = jQuery("[name=csrfmiddlewaretoken]").val(); + document.getElementById("loader_cf").style.display = "block"; + + // if reset button exists, hide it + if (document.getElementById("reset_div")) + document.getElementById("reset_div").style.display = "none"; + + $("#show_image").show() + $.ajax({ + method: 'POST', + url: '', + headers: { 'X-CSRFToken': csrftoken }, + data: { 'row': rowId, 'action': "counterfactual_select" }, + success: function (ret) { + // stop loader + document.getElementById("loader_cf").style.display = "none"; + + // add <input type="reset" value="Reset"> + if (!document.getElementById("reset")) { + var reset = document.createElement('input') + reset.setAttribute("id", "reset") + reset.setAttribute("type", "reset") + reset.setAttribute("value", "Reset") + $("#reset_div").append(reset) + } + // if reset button has been created once just display it + document.getElementById("reset_div").style.display = "block"; + + + var tb = document.createElement('table'); + ret = JSON.parse(ret) + row = ret["row"] + tb.innerHTML = row.trim(); + tb.setAttribute("id", "counterfactual_selected"); + tb.setAttribute("class", "dataframe") + + if (document.getElementById("counterfactual_selected")) { + $("#counterfactual_selected").remove(); + } + $("#counterfactual").append(tb); + document.getElementById("counterfactual_selected").style.display = "block"; + + fig = ret["fig"] + document.getElementById("tsne").innerHTML = ""; + $("#tsne").append(fig) + + + }, + error: function (row) { + console.log("it didnt work"); + } + }); + } +}); diff --git a/base/static/js/keep_scroll_on_load.js b/base/static/js/keep_scroll_on_load.js index 25f43badc..c3e7915bf 100644 --- a/base/static/js/keep_scroll_on_load.js +++ b/base/static/js/keep_scroll_on_load.js @@ -3,7 +3,7 @@ function setScreen() { window.scrollTo(0, yScreen); } function setScroll() { - var yScroll = window.pageYOffset; + var yScroll = window.scrollY; localStorage.setItem("yPos", yScroll); } function clearScreen() { diff --git a/base/templates/base/home.html b/base/templates/base/home.html index 27c2f23dc..cb0c772da 100644 --- a/base/templates/base/home.html +++ b/base/templates/base/home.html @@ -2,6 +2,26 @@ {% block content %} {% load static %} +<!-- <div class="container-fluid"> + <div class="row"> + <div class="col-sm d-flex justify-content-center" > + <div class="dataset"> + <div class="form-check"> + <input class="form-check-input" type="radio" name="flexRadioDefault" id="flexRadioDefault1"> + <label class="form-check-label" for="flexRadioDefault1"> + Breast Cancer Dataset + </label> + </div> + <div class="form-check"> + <input class="form-check-input" type="radio" name="flexRadioDefault" id="flexRadioDefault2"> + <label class="form-check-label" for="flexRadioDefault2"> + Other + </label> + </div> + </div> + </div> + </div> +</div> --> <div class="container-fluid"> <div class="row"> <div class="col-sm d-flex justify-content-center"> @@ -68,8 +88,10 @@ </div> </div> </div> + <br> <br> + <div class="container-fluid"> <div class="row"> <section> @@ -112,25 +134,20 @@ <button type="submit" class="button-6" role="button" name="preprocess" form="preprocess-form">Go!</button> </div> </div> + + <br> + <br> <div class="row"> - {% if pca %} - <div class="col-sm d-flex justify-content-center"> - {{ pca|safe }} - </div> - {% endif %} - - {% if tsne %} - <div class="col-sm d-flex justify-content-center"> - {{ tsne|safe }} - </div> - {% endif %} + {% if pca %} + <div class="col d-flex justify-content-center" id="pca"> + {{ pca|safe }} + </div> + {% endif %} </div> - - <br> - <br> </div> + <div class="container-fluid"> <form action="{% url 'home' %}" method="POST" id="traintest-form"> {% csrf_token %} @@ -147,7 +164,7 @@ <option type="submit" value="logit" selected>Logistic Regression</option> <option type="submit" value="xgb">XGBoost</option> <option type="submit" value="dt">Decision Tree</option> - <option type="submit" value="rt">Random Forest</option> + <option type="submit" value="rf">Random Forest</option> </select> </label> </div> @@ -200,6 +217,7 @@ {{ fig2|safe }} </div> {% endif %} + {% if clas_report %} <div class="col-sm d-flex justify-content-center"> <div class="scrollit"> @@ -209,4 +227,97 @@ {% endif %} </div> </div> + +<br> +<br> + +<div class="container-fluid"> + <div class="row"> + <section> + <label style="display:flex; + flex-direction:column; + align-items: center;"> + <h2> + <i class="fas fa-cog"></i> Counterfactuals + </h2> + <h5> + Pick a point on the graph + </h5> + + </label> + </section> + <div class="col-sm d-flex justify-content-center"> + </div> + </div> + + {% if tsne %} + <div class="row"> + <div class="col-md justify-content-center" id="tsne"> + {{ tsne|safe }} + </div> + </div> + {% endif %} + + <div class="row"> + <form action="{% url 'home' %}" method="POST" id="cf-form"> + {% csrf_token %} + <div class="col-md d-flex justify-content-center"> + <div class="scrollit" class="cftable" name="cfrow" id="cftable" style="border: 1px solid #bbb; padding: 5px; display: none; max-width: 65%;" > + </div> + </div> + </form> + </div> + <div class="row"> + <div class="col-sm d-flex justify-content-center"> + <button id="cfbtn" type="submit" class="button-6" role="button" name="cf" form="cf-form" style="display:none;">Run Counterfactuals!</button> + </div> + </div> + +{% if cfrow_og and cfrow_cf %} + <div class="row" id="og_cf_headers"> + <div class="col d-flex justify-content-center"> + <h2> + Original data + </h2> + </div> + <span style="display: none; width: 48px; height: 48px;" id="reset_div"> + </span> + <span class="loader" id="loader_cf" style="display: none;"> + </span> + <div class="col d-flex justify-content-center" style="select:invalid { color: gray; }"> + <label style="display:flex; + flex-direction:column; + align-items: center;"> + <select id="cfrow_id" name="cfrow_id" style="scale:1.2;"> + <option value="" disabled selected hidden>Pick counterfactual...</option> + {% for row in cfdf_rows %} + <option value={{row}}>Counterfactual {{row}}</option> + {% endfor %} + </select> + </label> + </div> + </div> + + <div class="row" id="og_cf_row"> + <div class="col-md d-flex justify-content-center" > + <div class="scrollit" id="original_data"> + {{ cfrow_og|safe}} + </div> + </div> + <div class="col-md d-flex justify-content-center" > + <div class="scrollit" id="counterfactual"> + <div id="counterfactual_selected"> + {{ cfrow_cf|safe }} + </div> + </div> + </div> + </div> +{% endif %} +<br> +<br> +</div> +<div class="container-fluid"> +</div> +<br> +<br> {% endblock content%} diff --git a/base/urls.py b/base/urls.py index f70717da1..9ea45c9e2 100644 --- a/base/urls.py +++ b/base/urls.py @@ -1,10 +1,6 @@ from django.urls import path, include from . import views -from . import models - urlpatterns = [ path('', views.home, name="home"), - path('preprocess', views.preprocess, name="preprocess"), - path('stats', views.stats, name="stats"), ] \ No newline at end of file diff --git a/base/utils.py b/base/utils.py index 286f0ce78..e69de29bb 100644 --- a/base/utils.py +++ b/base/utils.py @@ -1,5 +0,0 @@ -def stats(feature1, feature2, df): - import plotly.express as px - fig = px.scatter(df, x=feature1, y=feature2, color='Churn') - fig = fig.to_html(full_html=False) - return fig diff --git a/base/views.py b/base/views.py index 7b4589ffd..5664b0b84 100644 --- a/base/views.py +++ b/base/views.py @@ -1,27 +1,26 @@ -from django.shortcuts import render, redirect +from django.shortcuts import render, HttpResponse import pandas as pd -from django.core.files.storage import FileSystemStorage +import json +import IPython import pickle, os from sklearn.impute import SimpleImputer import numpy as np -from pandas.api.types import is_string_dtype from pandas.api.types import is_numeric_dtype -from sklearn.metrics import accuracy_score, classification_report +from sklearn.metrics import classification_report import plotly.express as px -import plotly.graph_objects as go from sklearn.preprocessing import LabelEncoder import joblib from sklearn.decomposition import PCA from sklearn.manifold import TSNE +import dice_ml +import plotly.graph_objects as go -clas_report = None FILE_NAME = "dataset.csv" PROCESS_FILE_NAME = "dataset_preprocessed.csv" def home(request): - global clas_report # request.session.flush() if "fig" in request.session: fig = request.session.get("fig") @@ -43,6 +42,26 @@ def home(request): else: tsne = None + if "cfrow_og" in request.session: + cfrow_og = request.session.get("cfrow_og") + else: + cfrow_og = None + + if "cfrow_cf" in request.session: + cfrow_cf = request.session.get("cfrow_cf") + else: + cfrow_cf = None + + if "cfdf_rows" in request.session: + cfdf_rows = request.session.get("cfdf_rows") + else: + cfdf_rows = None + + if "clas_report" in request.session: + clas_report = request.session.get("clas_report") + else: + clas_report = None + if "excel_file_name" in request.session: excel_file_name = request.session.get("excel_file_name") else: @@ -60,15 +79,91 @@ def home(request): excel_file_name_preprocessed = PROCESS_FILE_NAME request.session["excel_file_name_preprocessed"] = excel_file_name_preprocessed + # ajax request condition + if request.headers.get("X-Requested-With") == "XMLHttpRequest": + if request.POST.get("action") == "click_graph": + # ajax request for graph click + # given the ide of the clicked point (through ajax request) + # locate the respective dataframe row + df = pd.read_csv(excel_file_name_preprocessed) + id = request.POST["row"] + row = df.iloc[[int(id)]] + projections = request.session.get("tsne_projection") + + # projections array is a list of pairs with the (x, y) + # coordinates for a point in tsne. These are actual absolute + # coordinates and not SVG. + # Now save the info for use in the future + request.session["clicked_point"] = projections[int(id)] + request.session["cfrow_id"] = request.POST["row"] + request.session["cfrow_og"] = row.to_html() + context = {"row": row.to_html()} + elif request.POST.get("action") == "counterfactual_select": + + # if <select> element is used, and a specific counterfactual + # is inquired to be demonstrated: + df = pd.read_csv("counterfactuals.csv") + id = request.POST["row"] + row = df.iloc[[int(id)]] + + projections = request.session.get("tsne_projection") + fig = joblib.load("tsne_cfs.sav") + + # tsne_cfs is a merged scatter of 2 other scatters + # scatter[0] contains all the points of tsne + # scatter[1] contains all the counterfactual points + # + the point for which counterfactuals were + # computed... + + # blur out all the points of the first scatter + fig.data[0].update( + opacity=0.3, + ) + + # from the second scatter (scatter[1]) locate the clicked_point + # coordinates and keep the index + l = fig.data[1] + clicked_id = -1 + for clicked_id, item in enumerate(list(zip(l.x, l.y))): + if ( + item[0] == request.session.get("clicked_point")[0] + and item[1] == request.session.get("clicked_point")[1] + ): + break + + # id is the index of the respective counter factual + # clicked_id is the index of the original point + # Blur all the points apart from these 2 + fig.data[1].update( + selectedpoints=[id, clicked_id], + unselected=dict( + marker=dict( + opacity=0.3, + ) + ), + ) + context = {"row": row.to_html(), "fig": fig.to_html()} + elif request.POST.get("action") == "reset_graph": + fig = request.session.get("tsne") + context = {"fig": fig} + + return HttpResponse(json.dumps(context)) + df = pd.DataFrame() if request.method == "POST": if "csv" in request.POST: + excel_file = request.FILES["excel_file"] excel_file_name = request.FILES["excel_file"].name fig = None fig2 = None + pca = None + tsne = None + clas_report = None + cfrow_cf = None + cfrow_og = None # here we dont use the name of the file since the # uploaded file is not yet saved @@ -79,11 +174,11 @@ def home(request): # fs.save(excel_file_name, excel_file) df = pd.read_csv(excel_file) - df.drop(["id"], axis=1, inplace=True) + # df.drop(["id"], axis=1, inplace=True) df.to_csv(excel_file_name, index=False) - feature1 = df.columns[0] - feature2 = df.columns[1] + feature1 = df.columns[3] + feature2 = df.columns[2] request.session["feature1"] = feature1 request.session["feature2"] = feature2 @@ -108,7 +203,8 @@ def home(request): mode = request.POST.get("colorRadio") model = request.POST.get("model") test_size = float(request.POST.get("split_input")) - print(test_size, mode, model) + + request.session["model"] = model if mode == "train": if model == "logit": con = training(excel_file_name_preprocessed, "logit", test_size) @@ -122,6 +218,9 @@ def home(request): elif model == "svm": con = training(excel_file_name_preprocessed, "svm", test_size) + elif model == "rf": + con = training(excel_file_name_preprocessed, "rf", test_size) + fig2 = con["fig2"] clas_report = con["clas_report"].to_html() elif mode == "test": @@ -137,6 +236,9 @@ def home(request): elif model == "svm": con = testing(excel_file_name_preprocessed, "svm") + elif model == "rf": + con = training(excel_file_name_preprocessed, "rf", test_size) + fig2 = con["fig2"] clas_report = con["clas_report"].to_html() @@ -146,21 +248,16 @@ def home(request): # if file for preprocessing does not exist create it # also apply basic preprocessing if os.path.exists(excel_file_name_preprocessed) == False: - # generate filename idx = excel_file_name.index(".") excel_file_name_preprocessed = ( excel_file_name[:idx] + "_preprocessed" + excel_file_name[idx:] ) - # save file for preprocessing preprocess_df = pd.read_csv(excel_file_name) - fs = FileSystemStorage() # defaults to MEDIA_ROOT request.session["excel_file_name_preprocessed"] = ( excel_file_name_preprocessed ) - preprocess_df.to_csv(excel_file_name_preprocessed, index=False) - preprocess_df.drop( ["perimeter_mean", "area_mean"], axis=1, inplace=True ) @@ -180,18 +277,21 @@ def home(request): inplace=True, ) - # preprocess_df.drop(["id"], axis=1, inplace=True) le = LabelEncoder() preprocess_df["diagnosis"] = le.fit_transform( preprocess_df["diagnosis"] ) + + preprocess_df.to_csv(excel_file_name_preprocessed, index=False) else: preprocess_df = pd.read_csv(excel_file_name_preprocessed) preprocess(preprocess_df, value_list, excel_file_name_preprocessed) + preprocess_df.drop(["id"], axis=1, inplace=True) + # PCA pca = PCA() - pca.fit(preprocess_df) + pca.fit(preprocess_df.loc[:, "radius_mean":]) exp_var_cumul = np.cumsum(pca.explained_variance_ratio_) pca = px.area( x=range(1, exp_var_cumul.shape[0] + 1), @@ -199,31 +299,112 @@ def home(request): labels={"x": "# Components", "y": "Explained Variance"}, ).to_html() - features = preprocess_df.loc[:, :"compactness_se"] - + # tSNE tsne = TSNE(n_components=2, random_state=0) - projections = tsne.fit_transform(features) + projections = tsne.fit_transform( + preprocess_df.drop(["diagnosis"], axis=1).values + ) + tsne_df = pd.DataFrame( + { + "0": projections[:, 0], + "1": projections[:, 1], + "diagnosis": preprocess_df["diagnosis"], + } + ) tsne = px.scatter( - projections, - x=0, - y=1, - color=preprocess_df.diagnosis, - labels={"color": "diagnosis"}, - ).to_html() + tsne_df, + x="0", + y="1", + color="diagnosis", + color_continuous_scale=px.colors.sequential.Rainbow, + ) + + tsne.update_layout(clickmode="event+select") + # tsne_opacity.update_layout(clickmode="event+select") + pickle.dump(tsne, open("tsne.sav", "wb")) + tsne = tsne.to_html() + request.session["tsne_projection"] = projections.tolist() + elif "cf" in request.POST: + + excel_file_name_preprocessed = request.session.get( + "excel_file_name_preprocessed" + ) + + df = pd.read_csv(excel_file_name_preprocessed) + df_id = request.session.get("cfrow_id") + model = request.session.get("model") + row = df.iloc[[int(df_id)]] + counterfactuals(row, model, excel_file_name_preprocessed) + + # get coordinates of the clicked point (saved during the actual click) + clicked_point = request.session.get("clicked_point") + clicked_point_df = pd.DataFrame( + { + "0": clicked_point[0], + "1": clicked_point[1], + "diagnosis": row.diagnosis, + } + ) + clicked_point_df.reset_index(drop=True) + + # tSNE + cf_df = pd.read_csv("counterfactuals.csv") + + # get rows count + request.session["cfdf_rows"] = cf_df.index.values.tolist() + + # select a cf randomly for demonstration + request.session["cfrow_cf"] = cf_df.iloc[:1].to_html() + + df_merged = pd.concat( + [cf_df, df.drop("id", axis=1)], ignore_index=True, axis=0 + ) + tsne_cf = TSNE(n_components=2, random_state=0) + projections = tsne_cf.fit_transform( + df_merged.drop(["diagnosis"], axis=1).values + ) + + cf_df = pd.DataFrame( + { + "0": projections[:3, 0], + "1": projections[:3, 1], + "diagnosis": cf_df.diagnosis.iloc[:3], + } + ) + + cf_df = pd.concat([cf_df, clicked_point_df], ignore_index=True, axis=0) + + tsne = joblib.load("tsne.sav") + cf_s = px.scatter( + cf_df, + x="0", + y="1", + color="diagnosis", + color_continuous_scale=px.colors.sequential.Rainbow, + ) + + cf_s.update_traces( + marker=dict( + size=10, + symbol="circle", + ) + ) + + tsne.add_trace(cf_s.data[0]) + pickle.dump(tsne, open("tsne_cfs.sav", "wb")) + tsne = tsne.to_html() else: - if os.path.exists(excel_file_name) == False: excel_file_name = "dataset.csv" request.session["excel_file_name"] = excel_file_name - df = pd.read_csv(excel_file_name) - fig2 = None + df = pd.read_csv(excel_file_name) # just random columns to plot - feature1 = df.columns[0] - feature2 = df.columns[1] + feature1 = df.columns[3] + feature2 = df.columns[2] request.session["feature1"] = feature1 request.session["feature2"] = feature2 fig = stats( @@ -237,10 +418,11 @@ def home(request): request.session["fig2"] = fig2 request.session["pca"] = pca request.session["tsne"] = tsne + request.session["clas_report"] = clas_report - data_to_display = df[:5].to_html() + data_to_display = df[:10].to_html(index=False) request.session["data_to_display"] = data_to_display - labels = df.columns + labels = df.columns[2:] context = { "data_to_display": data_to_display, "excel_file": excel_file_name, @@ -252,6 +434,9 @@ def home(request): "clas_report": clas_report, "pca": pca, "tsne": tsne, + "cfrow_og": cfrow_og, + "cfrow_cf": cfrow_cf, + "cfdf_rows": cfdf_rows, } return render(request, "base/home.html", context) @@ -278,6 +463,9 @@ def stats(name, feature1, feature2): else: # they both are categorical so do scatter fig = px.histogram(df, x=feature1, color=feature2) + + fig.update_layout(clickmode="event+select") + fig = fig.to_html(full_html=False) return fig @@ -285,27 +473,30 @@ def stats(name, feature1, feature2): def preprocess(data, value_list, name): from sklearn.preprocessing import StandardScaler + ids = data["id"] + data = data.drop(["id"], axis=1) for type in value_list: if type == "std": # define standard scaler scaler = StandardScaler() y = data["diagnosis"] - if is_numeric_dtype(data["diagnosis"]): - # if class column is numeric do not - # apply preprocessing - data = data.drop(["diagnosis"], axis=1) + # if class column is numeric do not + # apply preprocessing + data = data.drop(["diagnosis"], axis=1) # transform data cols = data.select_dtypes(np.number).columns data[cols] = scaler.fit_transform(data[cols]) - y = y.to_frame() - data = data.join(y) + data = pd.concat([y.to_frame(), data], axis=1, ignore_index=False) if type == "onehot": data = pd.get_dummies(data) if type == "imp": + y = data["diagnosis"] + data = data.drop(["diagnosis"], axis=1) + data_numeric = data.select_dtypes(exclude=["object"]) data_categorical = data.select_dtypes(exclude=["number"]) imp = SimpleImputer(missing_values=np.nan, strategy="most_frequent") @@ -318,17 +509,20 @@ def preprocess(data, value_list, name): [data_numeric, data_categorical], axis=1, ignore_index=False ) + data = pd.concat([y.to_frame(), data], axis=1, ignore_index=False) + new = pd.concat([ids.to_frame(), data], axis=1, ignore_index=False) os.remove(name) - data.to_csv(name, index=False) + new.to_csv(name, index=False) return def training(name, type, test_size=0.7): data = pd.read_csv(name) + X = data.drop("id", axis=1) from sklearn.model_selection import train_test_split - y = data["diagnosis"] - X = data.drop("diagnosis", axis=1) + y = X["diagnosis"] + X = X.drop("diagnosis", axis=1) X_train, X_test, y_train, y_test = train_test_split( X, y, shuffle=True, test_size=test_size, stratify=y, random_state=42 ) @@ -377,6 +571,16 @@ def training(name, type, test_size=0.7): importance = svc.coef_[0] model = svc + if "rf" == type: + from sklearn.ensemble import RandomForestClassifier + + rf = RandomForestClassifier() + rf.fit(X_train, y_train) + y_pred = rf.predict(X_test) + filename = "rf.sav" + importance = rf.feature_importances_ + model = rf + clas_report = classification_report(y_test, y_pred, output_dict=True) clas_report = pd.DataFrame(clas_report).transpose() clas_report = clas_report.sort_values(by=["f1-score"], ascending=False) @@ -423,6 +627,13 @@ def testing(name, type): importance = svc.coef_[0] model = svc + if "rf" == type: + filename = "rf.sav" + rf = joblib.load(filename) + y_pred = rf.predict(X_test) + importance = rf.feature_importances_ + model = rf + clas_report = classification_report(y_test, y_pred, output_dict=True) clas_report = pd.DataFrame(clas_report).transpose() clas_report = clas_report.sort_values(by=["f1-score"], ascending=False) @@ -433,3 +644,73 @@ def testing(name, type): "clas_report": clas_report, } return con + + +def counterfactuals(row, model, preprocessed_file): + df = pd.read_csv(preprocessed_file) + df = df.drop("id", axis=1) + row = row.drop(["diagnosis", "id"], axis=1) + if "logit" == model: + filename = "lg.sav" + clf = joblib.load(filename) + model = clf + + if "xgb" == model: + filename = "xgb.sav" + xgb = joblib.load(filename) + model = xgb + + if "dt" == model: + filename = "dt.sav" + dt = joblib.load(filename) + model = dt + + if "svm" == model: + filename = "svc.sav" + svc = joblib.load(filename) + model = svc + + if "rf" == model: + filename = "rf.sav" + rf = joblib.load(filename) + model = rf + + d = dice_ml.Data( + dataframe=df, + continuous_features=[ + "radius_mean", + "texture_mean", + "smoothness_mean", + "compactness_mean", + "concavity_mean", + "symmetry_mean", + "fractal_dimension_mean", + "radius_se", + "texture_se", + "smoothness_se", + "compactness_se", + "concavity_se", + "concave points_se", + "symmetry_se", + "fractal_dimension_se", + "compactness_worst", + "concavity_worst", + "concave points_worst", + "fractal_dimension_worst", + ], + outcome_name="diagnosis", + ) + + m = dice_ml.Model(model=model, backend="sklearn") + exp = dice_ml.Dice(d, m) + dice_exp = exp.generate_counterfactuals( + row, # The data from the 1st row of our dataframe + total_CFs=3, # Total number of Counterfactual Examples we want to print out. There can be multiple. + desired_class="opposite", # We want to convert the quality to the opposite one. + ) + + # save cfs + dice_exp.cf_examples_list[0].final_cfs_df.to_csv( + path_or_buf="counterfactuals.csv", index=False + ) + return dice_exp._cf_examples_list diff --git a/db.sqlite3 b/db.sqlite3 index fc5ba3736..3b5d8342d 100644 Binary files a/db.sqlite3 and b/db.sqlite3 differ diff --git a/extremum/__pycache__/settings.cpython-310.pyc b/extremum/__pycache__/settings.cpython-310.pyc index 3167a7003..d1a62b6c1 100644 Binary files a/extremum/__pycache__/settings.cpython-310.pyc and b/extremum/__pycache__/settings.cpython-310.pyc differ diff --git a/extremum/routing.py b/extremum/routing.py new file mode 100644 index 000000000..2211dd6f7 --- /dev/null +++ b/extremum/routing.py @@ -0,0 +1,3 @@ +from channels.routing import ProtocolTypeRouter + +application = ProtocolTypeRouter({}) \ No newline at end of file diff --git a/extremum/settings.py b/extremum/settings.py index 6baff807e..52f3b11f7 100644 --- a/extremum/settings.py +++ b/extremum/settings.py @@ -40,6 +40,9 @@ INSTALLED_APPS = [ "django.contrib.staticfiles", "base.apps.BaseConfig", "bootstrap5", + "django_plotly_dash.apps.DjangoPlotlyDashConfig", + "channels", + "channels_redis", ] MIDDLEWARE = [ @@ -105,6 +108,7 @@ AUTH_PASSWORD_VALIDATORS = [ # Internationalization # https://docs.djangoproject.com/en/5.0/topics/i18n/ +X_FRAME_OPTIONS = 'SAMEORIGIN' LANGUAGE_CODE = "en-us" @@ -114,6 +118,32 @@ USE_I18N = True USE_TZ = True +CRISPY_TEMPLATE = "bootsrap5" + +ASGI_APPLICATION = "extremum.routing.applications" + +CHANNEL_LAYERS = { + 'default': { + 'BACKEND': 'channels_redis.core.RedisChannelLayer', + 'CONFIG': { + 'hosts': [('127.0.0.1', 6379),], + }, + }, +} + +SSTATICFILRES_FINDERS = [ + "django.contrib.staticfiles.finders.FileSystemFinder", + "django.contrib.staticfiles.finders.AppDirectoriesFinder", + "django_plotly_dash.finder.DashAssetFinder", + "dajango_plotly_dash.finders.DashComponentsFinder", +] + +PLOTLY_COMPONENTS = [ + "dash_core_components", + "dash_html_components", + "dash_renderer", + "dpd_components", +] # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/5.0/howto/static-files/ diff --git a/templates/main.html b/templates/main.html index 72c2eff95..d0c598b80 100644 --- a/templates/main.html +++ b/templates/main.html @@ -11,7 +11,10 @@ <title>EXTREMUM</title> <meta name="viewport" content="'width=device-width, initial-scale=1" /> <link rel="stylesheet" href="{% static 'css/style.css' %}"> - <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script> + <script src="https://cdn.plot.ly/plotly-latest.min.js" charset="utf-8"></script> + <script src="https://code.jquery.com/jquery-3.7.1.js" integrity="sha256-eKhayi8LEQwp4NKxN+CfCh+3qOVUtJn3QNZ0TciWLP4=" crossorigin="anonymous"></script> + <script src="https://cdnjs.cloudflare.com/ajax/libs/svg.js/3.2.0/svg.min.js" charset="utf-8"></script> + <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css"> </head> @@ -29,12 +32,13 @@ <br> <br> <body onscroll="setScroll()" onload="setScreen()"> - {% block content %} {% endblock content %} - + <script src="{% static 'js/click_on_graph.js' %}"></script> <script src="{% static 'js/hide_seek.js' %}"></script> <script src="{% static 'js/slider.js' %}"></script> <script src="{% static 'js/keep_scroll_on_load.js' %}"></script> + <script src="{% static 'js/counterfactuals.js' %}"></script> + <script src="{% static 'js/click_reset.js' %}"></script> </body> </html>