{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 5. Training an Anomaly Detection Model for Covid Anomaly Detection\n", "\n", "## Overview\n", "\n", "In this tutorial, we will train an anomaly detection model using a simple [LSTM-AutoEncoder model](https://www.medrxiv.org/content/10.1101/2021.01.08.21249474v1).\n", "Data can be obtained from [this link](https://iscteiul365-my.sharepoint.com/:u:/g/personal/oonia_iscte-iul_pt/ERZLm1ruUNpMqkSwjpqhE9wB_7loVWAC4yZWuIH2RKGOlQ?e=kD4HlI). This is a processed version of the data from the original Stanford dataset (Phase 2). The overall pre-processing pipeline is illustrated in the figure below.\n", "\n", "![preprocessing](stanford_data_processing.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Data was acquired from different sources (Garmin, FitBit, Apple Watch) and pre-processed to have a common format. In this form, the data has two columns: heart rate and number of user steps in the last minute. \n", "\n", "The processing pipeline is then applied to the data. The pipeline is composed of the following steps:\n", "1. Once the data is standardized, the resting heart rate is extracted (``Resting Heart Rate Extractor``, in the figure). This process takes as input `min_minutes_rest`, the number of minutes that the user has to be at rest for the heart rate to be considered as resting. The process looks at the user steps and, when the step count is 0 for `min_minutes_rest` minutes, the heart rate is considered as resting. At the end of this process, we will have a new dataframe with the date (`datetime` column) and the resting heart rate of the last minute (`RHR` column).\n", "\n", "2. The smoother process is applied to the data (`Smoother`, in the figure). This process takes as input `smooth_window_sample`, the number of samples used to smooth the data, and `sample_rate`, the sample rate. 
The smoother process will apply a moving average filter to the data, with a window of `smooth_window_sample` samples. Then the data is downsampled to the rate given by `sample_rate`. This process will produce a new dataframe with the date (`datetime` column) and the resting heart rate at the desired sampling rate (`RHR` column).\n", "\n", "3. The third step is adding labels (`Label Adder`, in the figure). It is also illustrated in the figure below. This process takes 3 inputs: `baseline_days`, `before_onset`, and `after_onset`. The `baseline_days` is the number of days before the onset of the symptoms that we consider as baseline (in the figure below, this is 21 days). Thus, using the dataframe from the last step, a new column named `baseline` is added, which is a boolean column that is True if the date is in the baseline period (21 days before onset). The `before_onset` and `after_onset` are the number of days before and after the onset of the symptoms that we consider as the anomaly period (7 days before and 21 days after, in the figure below). A new column named `anomaly` is added, which is a boolean column that is True if the date is in the anomaly period. Finally, we also add a `status` column, which is a metadata column with a descriptive status of the date. It can be: \n", " - `normal`: if the date is in the baseline period; \n", " - `before onset`: if the date is in the period before the onset of the symptoms; \n", " - `onset`: if the date is the onset of the symptoms (day); \n", " - `after onset`: if the date is in the period after the onset of the symptoms, but before the recovery; \n", " - `recovered`: if the date is in the recovery period.\n", "\n", "4. Once the labels are added, we normalize the data (`Standardizer`, in the figure above). This process performs a Z-norm scaling on the data. The Z-norm scale is calculated as: $z = \\frac{x - \\mu}{\\sigma}$, where $x$ is the value, $\\mu$ is the mean of the column and $\\sigma$ is the standard deviation of the column. 
An important note here is that the mean and standard deviation are calculated only for the baseline period, and then applied to the entire dataset.\n", "\n", "5. The last step is to create the sequences (`Transposer`, in the figure), which groups $n$ rows and transforms them into columns (features). This process takes as input the `window_size` and `overlap` parameters and creates sequences of `window_size` samples with an overlap of `overlap` samples. Thus, if we have a dataset with 100 samples, a `window_size` of 20 and an `overlap` of 0, we will have 5 sequences of 20 samples each (*i.e.* 5 rows with 20 columns). Each element of the sequence will be a column in the dataframe, numbered from 0 to 19. Thus, for example, the sequences will have columns `RHR-0`, `RHR-1`, ..., `RHR-19`, where the first row is the first 20 samples, the second row is the second 20 samples, and so on. This is useful as it is the format that the LSTM-AutoEncoder model expects as input. An important note is that we do not mix sequences from anomaly and non-anomaly periods. Thus, no label is mixed, that is, an anomaly sample only has anomaly time-steps.\n", "\n", "This will produce a dataframe (CSV file) for each user. In the processed dataset, we joined all users into a single file and added a `participant_id` column to identify the user. This makes it easier to work with the data in the next steps." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![labeling](anomaly_periods.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We already generated several files, with different parameters and operations of the pre-processing pipeline:\n", "* `rhr_df`: dataframe with the resting heart rate without normalization (step 4) and transposing (step 5). 
The `min_minutes_rest` is 12, `smooth_window_sample` is 400, `sample_rate` is 1 hour, `baseline_days` is 21, `before_onset` is 7, and `after_onset` is 21.\n", "* `rhr_df_scaled`: same as `rhr_df`, but with normalization.\n", "* `windowed_16_overlap_0_rate_10min_df`: same dataframe as `rhr_df`, with the resting heart rate transposed (step 5) but not normalized. The `window_size` is 16, `overlap` is 0, and `sample_rate` is 10 minutes.\n", "* `windowed_16_overlap_0_rate_10min_scaled_df`: same dataframe as `windowed_16_overlap_0_rate_10min_df`, but with normalization.\n", "\n", "**NOTE**: The files follow this naming convention: `windowed_{window_size}_overlap_{overlap}_rate_{sample_rate}_df.csv`. If `sample_rate` is omitted, it defaults to 1 hour.\n", "**NOTE**: The file names may end with `fold_X`, where `X` is the fold number. This is used for cross-validation purposes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's import some libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from ssl_tools.data.data_modules.covid_anomaly import CovidUserAnomalyDataModule\n", "from ssl_tools.utils.data import get_full_data_split\n", "from ssl_tools.models.nets.lstm_ae import LSTMAutoencoder\n", "import lightning as L\n", "import torch\n", "import numpy as np\n", "from torchmetrics import MeanSquaredError" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load data and inspect" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
" ], "text/plain": [ " datetime RHR-0 RHR-1 RHR-2 RHR-3 RHR-4 \\\n", "0 2027-01-14 21:00:00 1.170175 0.653752 -0.392374 -1.431553 -2.129013 \n", "1 2027-01-15 05:00:00 -5.668570 -6.373289 -6.937363 -7.102118 -6.975790 \n", "2 2027-01-15 13:00:00 -4.415848 -3.467073 -2.656756 -1.305630 -0.072756 \n", "3 2027-01-15 21:00:00 1.223064 0.472444 -0.424000 -1.145581 -1.355121 \n", "4 2027-01-16 05:00:00 -4.802627 -5.831013 -6.067744 -5.460156 -4.671143 \n", "... ... ... ... ... ... ... \n", "31732 2024-12-13 00:00:00 -0.180702 -0.499793 -0.749829 -0.868485 -0.966754 \n", "31733 2024-12-13 08:00:00 -0.467943 -0.162740 0.092000 0.347840 0.636395 \n", "31734 2024-12-13 16:00:00 1.477526 1.657321 1.660344 1.656600 1.685652 \n", "31735 2024-12-14 00:00:00 1.728615 1.616265 1.509833 1.380749 1.263744 \n", "31736 2024-12-14 08:00:00 1.136868 1.380418 1.642153 1.909381 2.114439 \n", "\n", " RHR-5 RHR-6 RHR-7 RHR-8 ... RHR-10 RHR-11 \\\n", "0 -2.755962 -3.681322 -4.674443 -5.668570 ... -6.937363 -7.102118 \n", "1 -6.554774 -6.112156 -5.396099 -4.415848 ... -2.656756 -1.305630 \n", "2 1.046195 1.530467 1.829053 1.223064 ... -0.424000 -1.145581 \n", "3 -2.321206 -3.124961 -3.928738 -4.802627 ... -6.067744 -5.460156 \n", "4 -3.408943 -2.237883 -1.187843 -0.062360 ... 2.266944 3.794465 \n", "... ... ... ... ... ... ... ... \n", "31732 -1.004670 -0.888210 -0.580762 -0.467943 ... 0.092000 0.347840 \n", "31733 0.958195 1.170514 1.301841 1.477526 ... 1.660344 1.656600 \n", "31734 1.747252 1.767329 1.793616 1.728615 ... 1.509833 1.380749 \n", "31735 1.139997 1.024205 0.946663 1.136868 ... 1.642153 1.909381 \n", "31736 2.282238 2.453691 2.587843 2.437232 ... 
2.359840 2.173400 \n", "\n", " RHR-12 RHR-13 RHR-14 RHR-15 anomaly baseline label \\\n", "0 -6.975790 -6.554774 -6.112156 -5.396099 False True normal \n", "1 -0.072756 1.046195 1.530467 1.829053 False False normal \n", "2 -1.355121 -2.321206 -3.124961 -3.928738 False False normal \n", "3 -4.671143 -3.408943 -2.237883 -1.187843 False False normal \n", "4 4.625745 4.827756 4.720000 4.677464 False False normal \n", "... ... ... ... ... ... ... ... \n", "31732 0.636395 0.958195 1.170514 1.301841 False False recovered \n", "31733 1.685652 1.747252 1.767329 1.793616 False False recovered \n", "31734 1.263744 1.139997 1.024205 0.946663 False False recovered \n", "31735 2.114439 2.282238 2.453691 2.587843 False False recovered \n", "31736 2.098140 1.967669 1.784512 1.561848 False False recovered \n", "\n", " participant_id \n", "0 P110465 \n", "1 P110465 \n", "2 P110465 \n", "3 P110465 \n", "4 P110465 \n", "... ... \n", "31732 P992022 \n", "31733 P992022 \n", "31734 P992022 \n", "31735 P992022 \n", "31736 P992022 \n", "\n", "[31737 rows x 21 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Read CSV data\n", "data_path = \"/workspaces/hiaac-m4/data/Stanford-COVID/processed/windowed_16_overlap_8_df_scaled.csv\"\n", "df = pd.read_csv(data_path)\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creating a [LightningDataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html).\n", "* The first parameter is the path to the CSV file.\n", "* `participants`: is a list of participants to include in dataset. If nothing is passed, all participants in CSV are included.\n", "* `batch_size`: is the batch size to use in the dataloader.\n", "* `num_workers`: is the number of workers to use in the dataloader.\n", "* `reshape`: is the shape of the input data. 
For LSTM-AutoEncoder, it is `(sequence_length, num_features)`, or, in our case `(16, 1)`\n", "\n", "**NOTE**: The training data is only data where baseline is True. The test data will be only data where baseline is False." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CovidUserAnomalyDataModule (Data=/workspaces/hiaac-m4/data/Stanford-COVID/processed/windowed_16_overlap_8_df_scaled.csv, 1 participant selected)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dm = CovidUserAnomalyDataModule(\n", " data_path,\n", " participants=[\"P992022\"],\n", " batch_size=32,\n", " num_workers=0,\n", " reshape=(16, 1),\n", ")\n", "dm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create the lightning model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LSTMAutoencoder(\n", " (backbone): _LSTMAutoEncoder(\n", " (lstm1): LSTM(1, 128, batch_first=True)\n", " (lstm2): LSTM(128, 64, batch_first=True)\n", " (repeat_vector): Linear(in_features=64, out_features=1024, bias=True)\n", " (lstm3): LSTM(64, 64, batch_first=True)\n", " (lstm4): LSTM(64, 128, batch_first=True)\n", " (time_distributed): Linear(in_features=128, out_features=1, bias=True)\n", " )\n", " (loss_fn): MSELoss()\n", ")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = LSTMAutoencoder(input_shape=(16, 1))\n", "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creating the Trainer" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: False\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU 
available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer = L.Trainer(max_epochs=100, devices=1, accelerator=\"cpu\")\n", "trainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fit the model using training data from the datamodule" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", " | Name | Type | Params\n", "----------------------------------------------\n", "0 | backbone | _LSTMAutoEncoder | 316 K \n", "1 | loss_fn | MSELoss | 0 \n", "----------------------------------------------\n", "316 K Trainable params\n", "0 Non-trainable params\n", "316 K Total params\n", "1.264 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "122a71df981c48c183eb2b4e7585103d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: | | 0/? [00:00 anomaly_threshold else 0 for loss in losses]" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
" ], "text/plain": [ " true predicted loss anomaly_threshold\n", "0 0 0 0.023700 0.374275\n", "1 0 0 0.091413 0.374275\n", "2 0 0 0.054299 0.374275\n", "3 0 0 0.007486 0.374275\n", "4 0 0 0.024601 0.374275\n", ".. ... ... ... ...\n", "89 1 0 0.089833 0.374275\n", "90 1 0 0.051562 0.374275\n", "91 1 0 0.132748 0.374275\n", "92 1 0 0.158610 0.374275\n", "93 1 0 0.025522 0.374275\n", "\n", "[94 rows x 4 columns]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_dataframe = pd.DataFrame(\n", " {\n", " \"true\": y_test,\n", " \"predicted\": y_test_hat,\n", " \"loss\": losses,\n", " \"anomaly_threshold\": anomaly_threshold,\n", " }\n", ")\n", "\n", "results_dataframe" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualizing Metrics and Confusion Matrix" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "F1-score: 0.0\n", "Recall: 0.0\n", "Balanced Accuracy: 0.5\n", "ROC AUC: 0.5\n" ] } ], "source": [ "from sklearn.metrics import f1_score, recall_score, balanced_accuracy_score, roc_auc_score\n", "\n", "# Extract true and predicted labels from the results_dataframe\n", "true_labels = results_dataframe['true']\n", "predicted_labels = results_dataframe['predicted']\n", "\n", "# Calculate the F1-score\n", "f1 = f1_score(true_labels, predicted_labels)\n", "\n", "# Calculate the recall\n", "recall = recall_score(true_labels, predicted_labels)\n", "\n", "# Calculate the balanced accuracy\n", "balanced_acc = balanced_accuracy_score(true_labels, predicted_labels)\n", "\n", "# Calculate the ROC AUC\n", "roc_auc = roc_auc_score(true_labels, predicted_labels)\n", "\n", "# Print the results\n", "print(\"F1-score:\", f1)\n", "print(\"Recall:\", recall)\n", "print(\"Balanced Accuracy:\", balanced_acc)\n", "print(\"ROC AUC:\", roc_auc)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { 
"image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "from sklearn.metrics import confusion_matrix\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "# Get the true and predicted labels from the results_dataframe\n", "true_labels = results_dataframe['true']\n", "predicted_labels = results_dataframe['predicted']\n", "\n", "# Compute the confusion matrix\n", "cm = confusion_matrix(true_labels, predicted_labels)\n", "\n", "# Define the class labels\n", "class_labels = ['Normal', 'Anomaly']\n", "\n", "# Plot the confusion matrix\n", "plt.figure(figsize=(8, 6))\n", "plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n", "plt.title('Confusion Matrix')\n", "plt.colorbar()\n", "tick_marks = np.arange(len(class_labels))\n", "plt.xticks(tick_marks, class_labels, rotation=45)\n", "plt.yticks(tick_marks, class_labels)\n", "plt.xlabel('Predicted Label')\n", "plt.ylabel('True Label')\n", "\n", "# Add the values to the confusion matrix plot\n", "thresh = cm.max() / 2.\n", "for i in range(cm.shape[0]):\n", " for j in range(cm.shape[1]):\n", " plt.text(j, i, format(cm[i, j], 'd'),\n", " horizontalalignment=\"center\",\n", " color=\"white\" if cm[i, j] > thresh else \"black\")\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 2 }