Simple Linear Regression from scratch in Rust

As one of the oldest and easiest Machine Learning algorithms, implementing Simple Linear Regression can be an eye-opening and rewarding experience for anyone new to Machine Learning, Deep Learning and AI.

In this tutorial, we are going to implement Simple Linear Regression in Rust. To really internalize the algorithm, we won’t use any existing math or Machine Learning frameworks. The only external crate we will use is plotlib, which will allow us to visualize the results in a nice svg graphic.
You can find the source code for the entire project on


Prerequisites

  • basic knowledge of the Rust programming language
  • basic high school math / statistics skills

Simple Linear Regression

The Linear Regression algorithm allows us to predict a dependent variable y based on a set of independent variables x1, x2, .., xn.
Based on a set of existing (x, y) pairs, our goal is to create a prediction function y(x):

y(x1..xn) = b0 + b1 * x1 + b2 * x2 + .. + bn * xn

where b0 is the so-called intercept (y when all inputs are 0) and b1..bn are the coefficients that will be applied to the input values.

In the case of Simple Linear Regression, we have only one independent variable x, so we can simplify the above function to:

y(x) = b0 + b1 * x

Once we figure out the coefficient and the intercept, we can use them to make predictions for any new value of x by simply solving the equation. For example, if b0 == 5 and b1 == 2, y(4) can be calculated like this:

y(4) = 5 + 2 * 4 = 13
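
As a minimal sketch, the prediction function can be written in a few lines of Rust. The function name predict_linear and the use of slices are illustrative assumptions, not part of the project code. With a single coefficient it reduces to the simple case above.

```rust
/// Evaluates y = b0 + b1 * x1 + .. + bn * xn.
/// `predict_linear` is a hypothetical helper used only for illustration.
fn predict_linear(b0: f32, coefficients: &[f32], inputs: &[f32]) -> f32 {
    assert_eq!(coefficients.len(), inputs.len(), "one coefficient per input");
    // Multiply each coefficient with its input and add everything up.
    let weighted_sum: f32 = coefficients
        .iter()
        .zip(inputs.iter())
        .map(|(b, x)| b * x)
        .sum();
    b0 + weighted_sum
}

fn main() {
    // Simple Linear Regression is the special case with one input.
    let y = predict_linear(5.0, &[2.0], &[4.0]);
    println!("y(4) = {}", y);
}
```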

Example data set

To test our algorithm, I made up the following simple dataset:

[(1, 1), (2, 3), (3, 2), (4, 3), (5, 5)]

which can be visualized in the following scatter plot:

Our goal is to draw a straight line that is as close as possible to each of these points. Then we can use that line to predict the y value for any x.

Estimating the intercept and coefficient

To estimate b0 and b1, we can use the following statistical equations:

b1 = Cov(x, y) / Var(x)
b0 = mean(y) - b1 * mean(x)

where Cov is the covariance, Var is the variance and mean is the mean of an array.
We can calculate all three of them with the following formulas:

mean(x) = sum(x) / length(x)
Var(x) = sum((x[i] - mean(x))^2) / length(x)
Cov(x, y) = sum((x[i] - mean(x)) * (y[i] - mean(y))) / length(x)
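
As a worked example, here are the estimates for the five points of the example data set, with variance and covariance divided by the number of points n = 5 (the convention the code in this tutorial follows). Because b1 is a ratio of the two, the factor 1/n cancels and does not affect the result.

```latex
\begin{aligned}
\operatorname{mean}(x) &= \tfrac{1+2+3+4+5}{5} = 3, \qquad
\operatorname{mean}(y) = \tfrac{1+3+2+3+5}{5} = 2.8 \\
\operatorname{Var}(x) &= \tfrac{(-2)^2 + (-1)^2 + 0^2 + 1^2 + 2^2}{5} = \tfrac{10}{5} = 2 \\
\operatorname{Cov}(x, y) &= \tfrac{(-2)(-1.8) + (-1)(0.2) + (0)(-0.8) + (1)(0.2) + (2)(2.2)}{5} = \tfrac{8}{5} = 1.6 \\
b_1 &= \tfrac{1.6}{2} = 0.8 \\
b_0 &= 2.8 - 0.8 \cdot 3 = 0.4
\end{aligned}
```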

To learn more about the math behind this algorithm, take a look at Simple Linear Regression.

Project structure

With all the basics laid out, let’s jump right into the code. I structured the project like this:

Our Linear Regression model will be implemented inside regression\ It will make use of utils\, which contains statistical helper functions to calculate the mean, variance and covariance as defined above.
(The module contains a helper function for unit tests, feel free to take a look at it in your own time).

To make our code usable as a library, we will reference the linear_regression module inside the file and then use it as crate from

Implementing Mean, Variance and Covariance

Let’s start by implementing the 3 math operations described above, as they make up the core of our algorithm. We will put them inside the file:

// Arithmetic mean; returns 0 for an empty Vec.
pub fn mean(values : &Vec<f32>) -> f32 {
    if values.len() == 0 {
        return 0f32;
    }

    return values.iter().sum::<f32>() / (values.len() as f32);
}

// Variance (mean squared deviation from the mean); returns 0 for an empty Vec.
pub fn variance(values : &Vec<f32>) -> f32 {
    if values.len() == 0 {
        return 0f32;
    }

    let mean = mean(values);
    return values.iter()
            .map(|x| f32::powf(x - mean, 2 as f32))
            .sum::<f32>() / values.len() as f32;
}

// Covariance of x and y; panics if the two Vecs differ in length.
pub fn covariance(x_values : &Vec<f32>, y_values : &Vec<f32>) -> f32 {
    if x_values.len() != y_values.len() {
        panic!("x_values and y_values must be of equal length.");
    }

    let length : usize = x_values.len();
    if length == 0usize {
        return 0f32;
    }

    let mut covariance : f32 = 0f32;
    let mean_x = mean(x_values);
    let mean_y = mean(y_values);

    for i in 0..length {
        covariance += (x_values[i] - mean_x) * (y_values[i] - mean_y);
    }

    return covariance / length as f32;
}

Note that we return 0 when an empty Vec is provided to any of these functions. If x_values and y_values differ in length, covariance panics. Besides that, the code should be pretty self-explanatory: it simply performs the calculations we defined above.
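
Returning 0 for empty input and panicking on a length mismatch are design choices rather than requirements. As a hedged alternative, the sketch below reports both cases through Option and works on slices. The names mean_opt and covariance_opt are hypothetical and do not appear in the project.

```rust
/// Mean of a slice, or None when the slice is empty.
/// `mean_opt` is an illustrative name, not part of the project.
fn mean_opt(values: &[f32]) -> Option<f32> {
    if values.is_empty() {
        return None;
    }
    Some(values.iter().sum::<f32>() / values.len() as f32)
}

/// Covariance of two slices, or None when they are empty or differ in length.
fn covariance_opt(x_values: &[f32], y_values: &[f32]) -> Option<f32> {
    if x_values.len() != y_values.len() {
        return None;
    }
    // `?` propagates None for empty input.
    let mean_x = mean_opt(x_values)?;
    let mean_y = mean_opt(y_values)?;

    // Pair up the values instead of indexing into both slices.
    let sum: f32 = x_values
        .iter()
        .zip(y_values.iter())
        .map(|(x, y)| (x - mean_x) * (y - mean_y))
        .sum();
    Some(sum / x_values.len() as f32)
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0, 5.0];
    let y = [1.0, 3.0, 2.0, 3.0, 5.0];
    println!("{:?}", covariance_opt(&x, &y));
    println!("{:?}", covariance_opt(&x, &y[..4]));
}
```

With this shape the caller decides how to react to bad input, at the cost of handling an Option at every call site.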

Implementing Simple Linear Regression

With our statistical functions ready, let’s dive into the actual algorithm. The interface of our Linear Regression model will look like this:

pub struct LinearRegression {
    pub coefficient: Option<f32>,
    pub intercept: Option<f32>
}

impl LinearRegression {
    pub fn new() -> LinearRegression { .. }
    pub fn fit(&mut self, x_values : &Vec<f32>, y_values : &Vec<f32>) { .. }
    pub fn predict_list(&self, x_values : &Vec<f32>) -> Vec<f32> { .. }
    pub fn predict(&self, x : f32) -> f32 { .. }
    pub fn evaluate(&self, x_test : &Vec<f32>, y_test: &Vec<f32>) -> f32 { .. }
}

The struct contains two fields holding the intercept and coefficient of our model (b0 and b1). To initialize them, we have to create an instance with the new function and then call its fit method (until then, both are None).
Afterward, we can use the other methods to make new predictions and evaluate the performance of our model (using Root Mean Squared Error).

Let’s go over these methods one by one:

pub fn new() -> LinearRegression {
    LinearRegression { coefficient: None, intercept: None }
}

Nothing fancy going on here: we return a new instance of our struct with both the coefficient and intercept initialized to None.

Let’s dig into the fit function. Remember that we can calculate b0 and b1 using statistics:

b1 = Cov(x, y) / Var(x)
b0 = mean(y) - b1 * mean(x)

Translated into Rust, our code will look like this:

pub fn fit(&mut self, x_values : &Vec<f32>, y_values : &Vec<f32>) {
    let b1 = stat::covariance(x_values, y_values) / stat::variance(x_values);
    let b0 = stat::mean(y_values) - b1 * stat::mean(x_values);

    self.intercept = Some(b0);
    self.coefficient = Some(b1);
}

All the heavy lifting is already implemented in our statistics helpers, so this code looks very straightforward. Remember to add use utils::stat; at the top of the file to make it compile.
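
One edge case worth noting: fit divides by Var(x), so if every x value is identical (or the input is empty) the division is 0 / 0 and both fields end up as NaN. The sketch below is one possible guard, written as a free function with assumed names (try_fit, and a plain (b0, b1) tuple as the result) so that it compiles on its own.

```rust
/// Estimates (b0, b1) for y = b0 + b1 * x.
/// Returns None when the line is not defined: empty input, mismatched
/// lengths, or zero variance in x. `try_fit` is an illustrative name.
fn try_fit(x_values: &[f32], y_values: &[f32]) -> Option<(f32, f32)> {
    if x_values.is_empty() || x_values.len() != y_values.len() {
        return None;
    }
    let n = x_values.len() as f32;
    let mean_x = x_values.iter().sum::<f32>() / n;
    let mean_y = y_values.iter().sum::<f32>() / n;

    // Sums of squares and cross products; the 1/n factor cancels in b1.
    let mut sxx = 0f32;
    let mut sxy = 0f32;
    for (x, y) in x_values.iter().zip(y_values.iter()) {
        sxx += (x - mean_x) * (x - mean_x);
        sxy += (x - mean_x) * (y - mean_y);
    }
    if sxx == 0.0 {
        return None; // all x are equal, so the slope is undefined
    }
    let b1 = sxy / sxx;
    Some((mean_y - b1 * mean_x, b1))
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0, 5.0];
    let y = [1.0, 3.0, 2.0, 3.0, 5.0];
    match try_fit(&x, &y) {
        Some((b0, b1)) => println!("b0 = {}, b1 = {}", b0, b1),
        None => println!("cannot fit a line to this data"),
    }
    // Points stacked on one x value have no finite slope.
    println!("{:?}", try_fit(&[2.0, 2.0], &[1.0, 3.0]));
}
```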

To make predictions for new values of x, we can use the equation for Linear regression defined above:

y(x) = b0 + b1 * x

In Rust, this can be achieved easily:

pub fn predict(&self, x : f32) -> f32 {
    if self.coefficient.is_none() || self.intercept.is_none() {
        panic!("fit(..) must be called first");
    }

    let b0 = self.intercept.unwrap();
    let b1 = self.coefficient.unwrap();

    return b0 + b1 * x;
}

We first check whether the intercept (b0) or the coefficient (b1) is not initialized yet and panic with an error message in that case. Otherwise, we get b0 and b1 by unwrapping both fields and return the result of calculating y(x).
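
The is_none checks followed by unwrap can also be expressed with a single pattern match. The sketch below assumes a stand-in struct named Model with the same two fields and returns an Option instead of panicking, which is a deliberate deviation from the method above.

```rust
/// Stand-in for the tutorial's struct; `Model` is a hypothetical name.
struct Model {
    coefficient: Option<f32>,
    intercept: Option<f32>,
}

impl Model {
    /// Returns Some(b0 + b1 * x) once both parameters are set, else None.
    fn try_predict(&self, x: f32) -> Option<f32> {
        match (self.intercept, self.coefficient) {
            (Some(b0), Some(b1)) => Some(b0 + b1 * x),
            _ => None,
        }
    }
}

fn main() {
    let unfitted = Model { coefficient: None, intercept: None };
    let fitted = Model { coefficient: Some(2.0), intercept: Some(5.0) };
    println!("{:?}", unfitted.try_predict(4.0));
    println!("{:?}", fitted.try_predict(4.0));
}
```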

To make predictions for a list of inputs, I added an additional predict_list method:

pub fn predict_list(&self, x_values : &Vec<f32>) -> Vec<f32> {
    let mut predictions = Vec::new();

    for i in 0..x_values.len() {
        predictions.push(self.predict(x_values[i]));
    }

    return predictions;
}

Here, we iterate over all input elements, predict their y-value and add it to our list of predictions, which is then returned.
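
The same loop can be written with iterator adapters, which avoids indexing and the mutable vector. This is a sketch with the parameters passed in directly (b0 and b1 as plain f32 arguments, and the assumed name predict_all) so that it stands alone; in the project it would call self.predict instead.

```rust
/// Applies y = b0 + b1 * x to every input and collects the results.
/// `predict_all` is an illustrative name, not part of the project.
fn predict_all(b0: f32, b1: f32, x_values: &[f32]) -> Vec<f32> {
    x_values.iter().map(|x| b0 + b1 * x).collect()
}

fn main() {
    let predictions = predict_all(5.0, 2.0, &[0.0, 1.0, 4.0]);
    println!("{:?}", predictions);
}
```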

Performance evaluation

All that is left now is the evaluate function, which will tell us how accurate our model is. As mentioned above, we will use the Root Mean Squared Error method.
We can calculate the Root Mean Squared Error with the following formulas:

mse = sum((prediction[i] - actual[i])^2) / length(actual)
rmse = sqrt(mse)

In Rust, this can be represented by the following function:

fn root_mean_squared_error(&self, actual : &Vec<f32>, predicted : &Vec<f32>) -> f32 {
    let mut sum_error = 0f32;
    let length = actual.len();

    for i in 0..length {
        sum_error += f32::powf(predicted[i] - actual[i], 2f32);
    }

    let mean_error = sum_error / length as f32;
    return mean_error.sqrt();
}

Before we can use this function, however, we have to predict the values for our test set. We will do this inside the evaluate function, which then returns the result of the root_mean_squared_error function:

pub fn evaluate(&self, x_test : &Vec<f32>, y_test: &Vec<f32>) -> f32 {
    if self.coefficient.is_none() || self.intercept.is_none() {
        panic!("fit(..) must be called first");
    }

    let y_predicted = self.predict_list(x_test);
    return self.root_mean_squared_error(y_test, &y_predicted);
}

Again, if either the coefficient or the intercept isn’t initialized, the code will panic with an error message. Otherwise, we call the predict_list method to create a list of predictions and pass it to our root_mean_squared_error function.

Testing the algorithm

With everything set up, it’s time to test our new Linear Regression algorithm! In the file, we can create a public main() function and initialize our model like this:

let mut model = linear_regression::LinearRegression::new();
let x_values = vec![1f32, 2f32, 3f32, 4f32, 5f32];
let y_values = vec![1f32, 3f32, 2f32, 3f32, 5f32];

model.fit(&x_values, &y_values);

We can use the methods we created to display the coefficient, intercept and accuracy:

println!("Coefficient: {0}", model.coefficient.unwrap());
println!("Intercept: {0}", model.intercept.unwrap());
println!("Accuracy: {0}", model.evaluate(&x_values, &y_values));

which will print:

Coefficient: 0.8
Intercept: 0.39999986
Accuracy: 0.69282037
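
The last value can be checked by hand. Note that despite the label, it is the Root Mean Squared Error, so lower is better and 0 would mean a perfect fit on these points. Using b0 = 0.4 and b1 = 0.8, the predictions for x = 1..5 are 1.2, 2.0, 2.8, 3.6 and 4.4:

```latex
\begin{aligned}
\text{errors} &= (1.2 - 1,\ 2.0 - 3,\ 2.8 - 2,\ 3.6 - 3,\ 4.4 - 5) = (0.2,\ -1.0,\ 0.8,\ 0.6,\ -0.6) \\
\text{mse} &= \tfrac{0.04 + 1.00 + 0.64 + 0.36 + 0.36}{5} = \tfrac{2.4}{5} = 0.48 \\
\text{rmse} &= \sqrt{0.48} \approx 0.6928
\end{aligned}
```

The printed intercept differs slightly from 0.4, most likely because of f32 rounding in the intermediate sums.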

To make predictions, we can use both our prediction methods:

let y_predictions : Vec<f32> = model.predict_list(&x_values);
let y_prediction : f32 = model.predict(4f32);

Visualizing results

To see how well our algorithm performed without crunching numbers, it would be nice to get a visual representation of our predictions by plotting them in a 2D coordinate system. We can do so using the plotlib library:

let plot_actual = Scatter::from_vec(&actual)

let plot_prediction = Scatter::from_vec(&y_prediction)

let v = View::new()
    .x_range(-0., 6.)
    .y_range(0., 6.)


Don’t forget to import the crate and add the corresponding use directives:

extern crate plotlib;

use plotlib::scatter::Scatter;
use plotlib::scatter;
use plotlib::style::{Marker, Point};
use plotlib::view::View;
use plotlib::page::Page;

This will save a plot with the actual values and our predictions in a nicely formatted .svg file:

As you can see, our prediction matches the actual result quite well for most values. The only outliers are at x==2 and x==3, though both are still within acceptable bounds.


This concludes our trip through the Linear Regression algorithm. Feel free to check out the source code and play around with different input values. Let me know if you have any questions, hints or comments.