7.3 Heart Disease Classification Problem with Neural Networks
In this assignment, you’ll analyze heart disease data using exploratory data analysis, linear regression, and neural networks. This will give you hands-on experience with a real-world classification problem.
Setup
First, load the necessary libraries and data:
library(torch)
library(tidyverse)
# install.packages("kmed")
library(kmed)
set.seed(1234)
heart_disease <- heart %>%
  as_tibble() %>%
  mutate(diagnosis = if_else(class == 0, FALSE, TRUE)) %>%
  select(-class) %>%
  # randomly assign about 2/3 of the rows to training (train = 1) and 1/3 to test (train = 0)
  mutate(train = sample(0:1, size = nrow(heart), replace = TRUE, prob = c(1/3, 2/3)))
Variable | Description |
---|---|
age | age in years |
sex | T for male; F for female |
cp | chest pain type: 1 = typical angina, 2 = atypical angina, 3 = non-anginal pain, 4 = asymptomatic |
trestbps | resting blood pressure (mm Hg) |
chol | serum cholesterol (mg/dl) |
fbs | fasting blood sugar > 120 mg/dl (T/F) |
restecg | resting electrocardiographic results: 0 = normal, 1 = ST-T wave abnormality, 2 = left ventricular hypertrophy |
thalach | maximum heart rate achieved during exercise |
exang | exercise-induced angina (T/F) |
oldpeak | ST depression induced by exercise |
slope | slope of the peak exercise ST segment: 1 = upsloping, 2 = flat, 3 = downsloping |
ca | number of major vessels colored by fluoroscopy |
thal | 3 = normal, 6 = fixed defect, 7 = reversible defect |
class | diagnosis of heart disease: 0 = no, 1-4 = yes |
dplyr and ggplot2
Plot the age distribution for people diagnosed with heart disease. Are women diagnosed at older ages than men?
How common is each type of chest pain, and which types are more likely to indicate heart disease?
Is resting blood pressure correlated with heart disease? Compare men versus women.
Is cholesterol correlated with heart disease? Compare men versus women.
Is fasting blood sugar correlated with heart disease? Compare men versus women.
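The following sketch shows one way to start on the age, chest-pain, and blood-pressure/cholesterol questions. It assumes, as the variable table states, that sex is stored as TRUE/FALSE with TRUE for male. The density plot and the summary names (cp_summary, bp_chol) are illustrative choices, not the required answer.

```r
library(tidyverse)
library(kmed)  # provides the `heart` data set

heart_disease <- heart %>%
  as_tibble() %>%
  mutate(diagnosis = class != 0,
         sex = if_else(as.logical(sex), "male", "female")) %>%  # TRUE = male per the table
  select(-class)

# Age distribution of diagnosed patients, one density curve per sex
age_plot <- heart_disease %>%
  filter(diagnosis) %>%
  ggplot(aes(x = age, fill = sex)) +
  geom_density(alpha = 0.4) +
  labs(x = "Age (years)", y = "Density", fill = "Sex")
print(age_plot)

# How common each chest-pain type is, and the share diagnosed within each type
cp_summary <- heart_disease %>%
  group_by(cp) %>%
  summarise(n = n(), diagnosis_rate = mean(diagnosis)) %>%
  mutate(share = n / sum(n))
print(cp_summary)

# Mean resting blood pressure and cholesterol by sex and diagnosis
bp_chol <- heart_disease %>%
  group_by(sex, diagnosis) %>%
  summarise(mean_trestbps = mean(trestbps), mean_chol = mean(chol), .groups = "drop")
print(bp_chol)
```

The same group_by()/summarise() pattern works for fbs: replace the two means with mean(fbs).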
Linear Regression
Fit a linear probability model (ordinary least-squares regression on the 0/1 diagnosis) to predict heart disease using the training data set. Then compare its performance on the training data set with its performance on the test data set.
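This is a minimal sketch of the linear probability model, using the same train/test split as the setup code. It recodes diagnosis as a number (1 = disease) so that lm() fits ordinary least squares directly. It also treats a fitted value above 0.5 as a predicted diagnosis. That cutoff and the accuracy() helper are assumptions for illustration, not part of the assignment.

```r
library(tidyverse)
library(kmed)  # provides the `heart` data set

set.seed(1234)
heart_disease <- heart %>%
  as_tibble() %>%
  mutate(diagnosis = as.numeric(class != 0)) %>%  # 1 = heart disease, 0 = none
  select(-class) %>%
  mutate(train = sample(0:1, size = nrow(heart), replace = TRUE, prob = c(1/3, 2/3)))

train_set <- heart_disease %>% filter(train == 1) %>% select(-train)
test_set  <- heart_disease %>% filter(train == 0) %>% select(-train)

# Linear probability model: ordinary least squares on the 0/1 outcome
lpm <- lm(diagnosis ~ ., data = train_set)

# Hypothetical helper: share of rows where (fitted value > 0.5) matches the true label
accuracy <- function(model, data) {
  p_hat <- predict(model, newdata = data)  # fitted values are not constrained to [0, 1]
  mean((p_hat > 0.5) == (data$diagnosis == 1))
}

acc <- c(train = accuracy(lpm, train_set), test = accuracy(lpm, test_set))
print(acc)
```

If a rare factor level (for example restecg = 1) happens to land only in the test rows, predict() stops with a "new levels" error. Drawing a different split avoids this.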
Preparing Data for Neural Network
Neural networks are sensitive to the scale and format of input data. Let’s prepare our variables appropriately.
a) Interval data
Scale these variables by subtracting their means and dividing by their standard deviations: age, trestbps, chol, thalach, oldpeak.
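One way to standardize the interval variables uses dplyr's across(). For each column, the result matches as.numeric(scale(x)).

```r
library(tidyverse)
library(kmed)  # provides the `heart` data set

interval_vars <- c("age", "trestbps", "chol", "thalach", "oldpeak")

heart_scaled <- heart %>%
  as_tibble() %>%
  # z-score: subtract the column mean, then divide by the column standard deviation
  mutate(across(all_of(interval_vars), ~ (.x - mean(.x)) / sd(.x)))

# Each scaled column should now have mean ~0 and standard deviation 1
summarise(heart_scaled, across(all_of(interval_vars), list(mean = mean, sd = sd)))
```

This sketch computes the means and standard deviations from all rows, which is what the instruction describes. A common alternative is to compute them from the training rows only and apply them to both sets, so that no test-set information enters the preprocessing.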
b) One-Hot Encoding for Categorical data
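These categorical variables should be one-hot encoded for neural networks: cp, restecg, slope, ca, thal.
categorical <- heart_disease %>%
  select(cp, restecg, slope, ca, thal) %>%
  mutate(across(everything(), as.factor))  # ca may be stored as a number
model.matrix(~ . - 1, data = categorical,  # contrasts = FALSE keeps every level of every factor
             contrasts.arg = lapply(categorical, contrasts, contrasts = FALSE))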
Reflect: what is “one-hot encoding”?
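Here is a toy illustration of one-hot encoding. The four chest-pain values are made up for illustration. model.matrix() turns the single factor column into one 0/1 indicator column per level, so each row has exactly one 1. Coding cp as the numbers 1-4 instead would imply an ordering and equal spacing between the chest-pain types, which the categories do not have.

```r
# Hypothetical chest-pain values for four patients (made up for illustration)
pain <- data.frame(cp = factor(c("1", "4", "2", "4"), levels = c("1", "2", "3", "4")))

# One indicator column per level; each row has a single 1 in the column of its level
one_hot <- model.matrix(~ cp - 1, data = pain)
print(one_hot)
```

The second patient (cp = 4) becomes the row 0 0 0 1 across the columns cp1 to cp4.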
c) Binary variables
These binary variables can be converted to 0/1 format: diagnosis, exang, fbs, sex.
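This is a minimal sketch of the 0/1 conversion. It assumes exang, fbs, and sex are stored as TRUE/FALSE, as the variable table describes. as.integer() would not give 0/1 for a factor column.

```r
library(tidyverse)
library(kmed)  # provides the `heart` data set

heart_binary <- heart %>%
  as_tibble() %>%
  mutate(diagnosis = class != 0) %>%
  # as.integer() maps FALSE -> 0 and TRUE -> 1
  mutate(across(c(diagnosis, exang, fbs, sex), as.integer))

count(heart_binary, sex, diagnosis)
```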
d) Tensors
Create tensors x and y from the training rows (train == 1), and x_test and y_test from the test rows (train == 0).
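This self-contained sketch repeats steps a) through c) and then builds the four tensors with the torch package's torch_tensor() and dtype = torch_float(). y and y_test are shaped as one-column matrices (n x 1), so they line up with a network that has a single output unit. The names interval, binary, features, is_train, and outcome are hypothetical.

```r
library(torch)
library(tidyverse)
library(kmed)  # provides the `heart` data set

set.seed(1234)
heart_disease <- heart %>%
  as_tibble() %>%
  mutate(diagnosis = class != 0) %>%
  select(-class) %>%
  mutate(train = sample(0:1, size = nrow(heart), replace = TRUE, prob = c(1/3, 2/3)))

# a) z-score the interval variables
interval <- heart_disease %>%
  select(age, trestbps, chol, thalach, oldpeak) %>%
  mutate(across(everything(), ~ (.x - mean(.x)) / sd(.x)))

# b) one indicator column for every level of every categorical variable
categorical <- heart_disease %>%
  select(cp, restecg, slope, ca, thal) %>%
  mutate(across(everything(), as.factor))
one_hot <- model.matrix(~ . - 1, data = categorical,
                        contrasts.arg = lapply(categorical, contrasts, contrasts = FALSE))

# c) binary predictors as 0/1 (diagnosis becomes the outcome below)
binary <- heart_disease %>%
  select(exang, fbs, sex) %>%
  mutate(across(everything(), as.integer))

features <- cbind(as.matrix(interval), one_hot, as.matrix(binary))
is_train <- heart_disease$train == 1
outcome  <- matrix(as.numeric(heart_disease$diagnosis), ncol = 1)

# d) tensors: one row per patient; y and y_test are n x 1 columns
x      <- torch_tensor(features[is_train, ], dtype = torch_float())
x_test <- torch_tensor(features[!is_train, ], dtype = torch_float())
y      <- torch_tensor(outcome[is_train, , drop = FALSE], dtype = torch_float())
y_test <- torch_tensor(outcome[!is_train, , drop = FALSE], dtype = torch_float())
```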
Neural Network
Use your function from 7.2 to train a neural network on the training data set. Experiment with learning rates, the number of epochs, and the number of hidden units. Then test its performance on the test data set. How does the neural network compare to the linear probability model from earlier in this assignment?
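Your function from 7.2 is not reproduced here, so this sketch uses a hypothetical train_network() as a stand-in. It has one ReLU hidden layer, a sigmoid output, binary cross-entropy loss, and the Adam optimizer. These architecture and optimizer choices are assumptions: swap in your own function, then vary hidden_units, learning_rate, and epochs as the assignment asks. The data preparation is repeated so the block runs on its own.

```r
library(torch)
library(tidyverse)
library(kmed)  # provides the `heart` data set

# --- data preparation (same steps as parts a-d) ---
set.seed(1234)
heart_disease <- heart %>%
  as_tibble() %>%
  mutate(diagnosis = class != 0) %>%
  select(-class) %>%
  mutate(train = sample(0:1, size = nrow(heart), replace = TRUE, prob = c(1/3, 2/3)))
categorical <- heart_disease %>%
  select(cp, restecg, slope, ca, thal) %>%
  mutate(across(everything(), as.factor))
features <- cbind(
  heart_disease %>%
    select(age, trestbps, chol, thalach, oldpeak) %>%
    mutate(across(everything(), ~ (.x - mean(.x)) / sd(.x))) %>%
    as.matrix(),
  model.matrix(~ . - 1, data = categorical,
               contrasts.arg = lapply(categorical, contrasts, contrasts = FALSE)),
  heart_disease %>% select(exang, fbs, sex) %>% mutate(across(everything(), as.integer)) %>% as.matrix()
)
is_train <- heart_disease$train == 1
outcome  <- matrix(as.numeric(heart_disease$diagnosis), ncol = 1)
x      <- torch_tensor(features[is_train, ], dtype = torch_float())
x_test <- torch_tensor(features[!is_train, ], dtype = torch_float())
y      <- torch_tensor(outcome[is_train, , drop = FALSE], dtype = torch_float())
y_test <- torch_tensor(outcome[!is_train, , drop = FALSE], dtype = torch_float())

# --- hypothetical stand-in for the training function from 7.2 ---
train_network <- function(x, y, hidden_units = 16, learning_rate = 0.01, epochs = 500) {
  model <- nn_sequential(
    nn_linear(x$shape[2], hidden_units),
    nn_relu(),
    nn_linear(hidden_units, 1),
    nn_sigmoid()               # output is a probability of heart disease
  )
  optimizer <- optim_adam(model$parameters, lr = learning_rate)
  loss_fn <- nn_bce_loss()     # binary cross-entropy
  for (epoch in seq_len(epochs)) {
    optimizer$zero_grad()
    loss <- loss_fn(model(x), y)
    loss$backward()
    optimizer$step()
  }
  model
}

# Share of patients whose predicted probability falls on the correct side of 0.5
accuracy <- function(model, x, y) {
  p_hat <- as_array(model(x)$detach())
  mean((p_hat > 0.5) == (as_array(y) > 0.5))
}

torch_manual_seed(1234)
net <- train_network(x, y, hidden_units = 16, learning_rate = 0.01, epochs = 500)
acc <- c(train = accuracy(net, x, y), test = accuracy(net, x_test, y_test))
print(acc)
```

The test set is small, so a few misclassifications move the test accuracy noticeably. Before concluding that one setting beats another, or that the network beats the linear probability model, it helps to repeat the comparison with several seeds.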