Fitting a straight line in Julia: Flux machine learning

julia
Flux
linear model
Author

Jong-Hoon Kim

Published

January 4, 2024

Fitting a straight line in Julia

This post is my attempt to learn machine learning in Julia. Flux is a machine learning package written in Julia, and the contents of this page came from the Flux documentation.

Create training and test data

using Flux, Distributions, Random, Statistics
# create the data
# true parameter values: slope 4 and intercept 2, with Normal noise (standard deviation 1)
linear_model(x) = rand(Normal(4x + 2, 1))
linear_model (generic function with 1 method)

x_train, x_test = hcat(0:5...), hcat(6:10...)
([0 1 … 4 5], [6 7 … 9 10])
y_train, y_test = linear_model.(x_train), linear_model.(x_test)
([2.4026932318407495 5.949285185429201 … 17.316799499402293 20.00123315299341], [26.267203000424637 29.3186325205348 … 40.34823197526444 42.919475237241265])
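Because the data are random, the numbers above change on every run. A minimal sketch of one way to make them reproducible, assuming a fixed seed (42 is arbitrary) and using only Base's randn instead of Distributions; noisy_line is a hypothetical name, not something from Flux:

```julia
using Random

# Assumption: a fixed, arbitrary seed so the noisy data can be regenerated.
rng = MersenneTwister(42)

# Same model as above written with Base only:
# 4x + 2 plus standard normal noise, i.e. a draw from Normal(4x + 2, 1).
noisy_line(rng, x) = 4x + 2 + randn(rng)

xs = hcat(0:5...)                        # 1×6 matrix, one column per observation
ys = map(x -> noisy_line(rng, x), xs)    # 1×6 matrix of noisy responses
```

Re-creating the generator with the same seed and repeating the map call should give the same ys.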

Create a neural network model with a single layer

# y = σ.(W * x .+ bias); σ defaults to identity, so this is a plain linear model
model = Flux.Dense(1 => 1) # 2 parameters
Dense(1 => 1)       # 2 parameters
# prediction based on the untrained baseline model 
model(x_train)
1×6 Matrix{Float32}:
 0.0  -1.66877  -3.33753  -5.0063  -6.67506  -8.34383
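The untrained predictions hint at what the layer holds: the output at x = 0 is 0, so the bias starts at zero, and each step in x changes the output by about -1.66877, which is the randomly initialized weight. A rough sketch of the same computation in plain Julia, where dense is a hypothetical stand-in and not Flux's implementation:

```julia
# Hypothetical stand-in for Dense(1 => 1): an affine map with identity activation.
W = fill(-1.66877f0, 1, 1)   # 1×1 weight, read off the untrained predictions above
b = zeros(Float32, 1)        # Flux initializes the bias to zero by default
dense(x) = W * x .+ b

x = hcat(0:5...)             # same layout as x_train: one column per observation
dense(x)                     # 1×6 Float32 matrix
```

The weight is drawn at random when the layer is created, so a fresh session will give a different value.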

# define the loss function (mean squared error) used for training
loss(m, x, y) = mean(abs2.(m(x) .- y))
loss (generic function with 1 method)
loss(model, x_train, y_train)
335.3910074651219
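The loss is the mean squared error between predictions and targets. A tiny worked example on made-up numbers (not the post's data) shows the arithmetic; mse is a hypothetical helper that mirrors loss(m, x, y) but takes the predictions directly:

```julia
using Statistics

# Mean squared error on plain arrays: average of the squared differences.
mse(yhat, y) = mean(abs2.(yhat .- y))

yhat = [1.0 2.0 3.0]   # made-up predictions
y    = [1.0 4.0 7.0]   # made-up targets
mse(yhat, y)           # (0^2 + 2^2 + 4^2) / 3 = 20/3
```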

Train the model

Flux provides the Flux.train! function for model training. It takes the loss function, the model, the data, and an optimizer. The optimizer can be created, for example, with the Descent constructor, which implements plain gradient descent.

using Flux: train!

opt = Descent() # gradient descent with the default learning rate of 0.1
Descent(0.1)

data = [(x_train, y_train)]
1-element Vector{Tuple{Matrix{Int64}, Matrix{Float64}}}:
 ([0 1 … 4 5], [2.4026932318407495 5.949285185429201 … 17.316799499402293 20.00123315299341])

train!(loss, model, data, opt)
loss(model, x_train, y_train)
318.3230947174688

# check model parameters after one pass over the data (a single gradient step)
model.weight, model.bias
(Float32[9.375584;;], Float32[3.186498])
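To see what a single update does, here is a hand-written gradient descent step for y ≈ w*x + b under mean squared error. This is a sketch of the update rule (parameter minus learning rate times gradient), not Flux's internals; gd_step is a hypothetical helper, and the data are a noise-free line chosen so the numbers can be checked by hand:

```julia
using Statistics

# One plain gradient-descent step for y ≈ w*x + b under mean squared error.
# gd_step is a hypothetical helper, not part of Flux.
function gd_step(w, b, x, y; lr = 0.1)
    r  = w .* x .+ b .- y      # residuals of the current fit
    gw = 2 * mean(r .* x)      # derivative of the loss with respect to w
    gb = 2 * mean(r)           # derivative of the loss with respect to b
    return w - lr * gw, b - lr * gb
end

x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
y = 4 .* x .+ 2                # noise-free line, for illustration only
w1, b1 = gd_step(0.0, 0.0, x, y)
```

Starting from w = 0 and b = 0, the gradients are -250/3 and -24, so the step lands at w1 = 25/3 ≈ 8.33 and b1 = 2.4. The slope overshoots the true value of 4, which is similar to the jump to about 9.38 seen above and may explain why the loss drops so little after the first pass.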

# iteratively train the model
for epoch in 1:200
    train!(loss, model, data, opt)
end

Examine the trained model

# check the loss
loss(model, x_train, y_train)
0.3441841968180955
# check the model parameters
model.weight, model.bias
(Float32[3.637514;;], Float32[2.748934])
# check the model against the test data set
model(x_test)
1×5 Matrix{Float32}:
 24.574  28.2115  31.849  35.4866  39.1241
y_test
1×5 Matrix{Float64}:
 26.2672  29.3186  34.5149  40.3482  42.9195
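A straight-line fit also has a closed-form least-squares solution, which makes a handy cross-check on the trained parameters. A minimal sketch with Julia's backslash operator, using a noise-free stand-in for the training data so the expected answer is known:

```julia
# Ordinary least squares via the backslash operator, as a cross-check.
x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
y = 4 .* x .+ 2                  # noise-free stand-in for y_train
A = [x ones(length(x))]          # design matrix: slope column, intercept column
slope, intercept = A \ y         # least-squares solution; here slope ≈ 4, intercept ≈ 2
```

With the post's data, the same idea would be A = [vec(x_train) ones(6)] followed by A \ vec(y_train). The result is the minimizer of the training loss, so the gradient descent estimates can be compared against it.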