2023-11-16: Analysing Iris Data
As I've done in a few other notes, I will explore some simple data analysis tools applied to the iris dataset, but this time using Rust.
This example uses the clustering, csv and statistical crates.
Loading a CSV File
The iris dataset, "iris.data", is a CSV formatted dataset:
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
...
There are four numeric fields, followed by a class name.
The first part of the program must load the dataset, creating a vector of "IrisInstance" instances, where each "IrisInstance" holds the data from one row of the file.
The "IrisInstance" struct is:
struct IrisInstance {
    sepal_length: f64,
    sepal_width: f64,
    petal_length: f64,
    petal_width: f64,
    label: String,
}
We write a function to read the data from a given filename and return a vector of "IrisInstance" instances:
fn read_iris_data(filename: &str) -> Vec<IrisInstance> {
    csv::ReaderBuilder::new()                                       // <1>
        .has_headers(false)                                         // <2>
        .from_path(filename)                                        // <3>
        .expect("Could not read file")
        .records()                                                  // <4>
        .flatten() // remove possible record error                  // <5>
        .flat_map(|record| IrisInstance::from_csv_record(&record))  // <6>
        .collect::<Vec<IrisInstance>>()                             // <7>
}
- <1> We need to build a csv reader because ...
- <2> ... we do not have any headers in the file,
- <3> and we want the reader to work on the given filename.
- <4> At this point, we have a reader and can access the records iterator.
- <5> First, remove any error cases in the records.
- <6> Then try to convert each record into an instance of "IrisInstance" (see below).
- <7> Collect the converted instances into a vector and return it, with the caller taking ownership.
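The expect calls mean the program panics if the file cannot be opened or a record is malformed, which is fine for a small exploratory program. A variant that propagates errors to the caller instead might look like the following sketch (the name read_iris_data_checked is just for illustration):

fn read_iris_data_checked(filename: &str) -> Result<Vec<IrisInstance>, csv::Error> {
    let mut reader = csv::ReaderBuilder::new()
        .has_headers(false)
        .from_path(filename)?;              // propagate file-open errors
    let mut instances = Vec::new();
    for record in reader.records() {
        let record = record?;               // propagate malformed-record errors
        if let Some(instance) = IrisInstance::from_csv_record(&record) {
            instances.push(instance);
        }
    }
    Ok(instances)
}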
The from_csv_record function is implemented on IrisInstance. It makes a few checks to be sure that the record has the right number of fields and the required set of numbers and label, returning an Option type because any errors will lead to it returning None:
impl IrisInstance {
    fn from_csv_record(record: &csv::StringRecord) -> Option<IrisInstance> {
        if record.len() == 5 {
            let sepal_length = record.get(0).expect("expected a number")
                .parse::<f64>().unwrap_or_default();
            let sepal_width = record.get(1).expect("expected a number")
                .parse::<f64>().unwrap_or_default();
            let petal_length = record.get(2).expect("expected a number")
                .parse::<f64>().unwrap_or_default();
            let petal_width = record.get(3).expect("expected a number")
                .parse::<f64>().unwrap_or_default();
            Some(IrisInstance {
                sepal_length,
                sepal_width,
                petal_length,
                petal_width,
                label: String::from(record.get(4).expect("expected a string")),
            })
        } else {
            None
        }
    }
}
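One consequence of unwrap_or_default is that a numeric field which fails to parse silently becomes 0.0. If that is undesirable, a stricter variant (a sketch, not used by the program below) could reject the whole record instead:

impl IrisInstance {
    // Sketch only: return None if any numeric field fails to parse,
    // rather than defaulting it to 0.0.
    fn from_csv_record_strict(record: &csv::StringRecord) -> Option<IrisInstance> {
        if record.len() != 5 {
            return None;
        }
        Some(IrisInstance {
            sepal_length: record.get(0)?.parse().ok()?,
            sepal_width: record.get(1)?.parse().ok()?,
            petal_length: record.get(2)?.parse().ok()?,
            petal_width: record.get(3)?.parse().ok()?,
            label: record.get(4)?.to_string(),
        })
    }
}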
At this stage, we can read in the "iris.data" file and check that we get 150 instances:
fn main() {
    let instances = read_iris_data("iris.data");
    println!("Read {} instances.", instances.len());
}
Output:
Read 150 instances.
Descriptive Statistics
The idea here is to print some information about the attributes. To do this, we need to convert the values for each attribute into vectors, across all the instances. As we analyse four attributes in the same way, a closure is used to access the attribute information:
fn print_statistics<F>(instances: &[IrisInstance], attribute_name: &str, attribute_value: F)
    where F: Fn(&IrisInstance) -> f64,  // <1>
{
    let values: Vec<f64> = instances.iter().map(attribute_value).collect();  // <2>

    println!("{}", attribute_name);
    println!(" -- Minimum: {:.2}",
        values.iter().min_by(|a, b| a.total_cmp(b)).unwrap_or(&0.0));
    println!(" -- Maximum: {:.2}",
        values.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&0.0));
    println!(" -- Mean: {:.2}", statistical::mean(&values));  // <3>
    println!(" -- Stddev: {:.2}", statistical::standard_deviation(&values, None));
}
- <1> The attribute_value parameter is a closure, which returns the value of one of the instance's attributes.
- <2> Using the provided attribute_value closure, we collect all the values for a given attribute.
- <3> Information about the values is obtained, either using the built-in iterators or through a call to a statistics library.
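In main (shown in full at the end), the function is then called once per attribute, passing a closure that selects the relevant field:

print_statistics(&instances, "Sepal Length", |instance| instance.sepal_length);
print_statistics(&instances, "Sepal Width", |instance| instance.sepal_width);
print_statistics(&instances, "Petal Length", |instance| instance.petal_length);
print_statistics(&instances, "Petal Width", |instance| instance.petal_width);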
Output:
Sepal Length
 -- Minimum: 4.30
 -- Maximum: 7.90
 -- Mean: 5.84
 -- Stddev: 0.83
Sepal Width
 -- Minimum: 2.00
 -- Maximum: 4.40
 -- Mean: 3.05
 -- Stddev: 0.43
Petal Length
 -- Minimum: 1.00
 -- Maximum: 6.90
 -- Mean: 3.76
 -- Stddev: 1.76
Petal Width
 -- Minimum: 0.10
 -- Maximum: 2.50
 -- Mean: 1.20
 -- Stddev: 0.76
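The mean and standard deviation come from the statistical crate; other summaries can be computed with the standard library alone. For example, a small median helper (a sketch, assuming a non-empty slice) could be added alongside the existing statistics:

// Sketch: median of a non-empty slice, using only the standard library.
fn median(values: &[f64]) -> f64 {
    let mut sorted = values.to_vec();
    sorted.sort_by(|a, b| a.total_cmp(b));
    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}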
KMeans Clustering
Clustering is straightforward, with the right library!
let samples: Vec<Vec<f64>> = instances.iter()  // <1>
    .map(|instance| vec![instance.sepal_length, instance.sepal_width,
                         instance.petal_length, instance.petal_width])
    .collect();

let clustering = clustering::kmeans(3, &samples, 100);  // <2>

for i in 0..3 {
    println!("Centre: {:?}", clustering.centroids[i]);  // <3>
    println!("Cluster has: {} points",
        clustering.membership.iter().filter(|&n| *n == i).count());  // <4>
}
- KMeans clusters points as n-dimensional vectors, so we need to convert our instance attributes into vectors.
- Running the algorithm is simple: specify the target number of clusters, samples and maximum number of iterations.
- The centroids are retrieved as a vector.
- Membership is provided as a vector holding each point's cluster number, so we can count the members of a cluster using a filter.
Output:
Centre: Centroid([5.901612903225807, 2.748387096774194, 4.393548387096775, 1.4338709677419357])
Cluster has: 62 points
Centre: Centroid([5.005999999999999, 3.4180000000000006, 1.464, 0.2439999999999999])
Cluster has: 50 points
Centre: Centroid([6.8500000000000005, 3.073684210526315, 5.742105263157893, 2.0710526315789473])
Cluster has: 38 points
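Since each instance keeps its label, the cluster assignments can also be compared against the three species in the dataset (Iris-setosa, Iris-versicolor and Iris-virginica). A sketch of such a cross-tabulation, reusing instances and clustering from above:

for i in 0..3 {
    println!("Cluster {}:", i);
    for label in ["Iris-setosa", "Iris-versicolor", "Iris-virginica"] {
        // Count instances assigned to cluster i that carry this label.
        let count = instances.iter()
            .zip(&clustering.membership)
            .filter(|(instance, cluster)| **cluster == i && instance.label == label)
            .count();
        println!("  {}: {}", label, count);
    }
}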
Complete Program
Following is the complete "main.rs" program:
use clustering;
use csv;
use statistical;

// Hold the five values making up an IrisInstance
struct IrisInstance {
    sepal_length: f64,
    sepal_width: f64,
    petal_length: f64,
    petal_width: f64,
    label: String,
}

impl IrisInstance {
    fn from_csv_record(record: &csv::StringRecord) -> Option<IrisInstance> {
        if record.len() == 5 {
            let sepal_length = record.get(0).expect("expected a number")
                .parse::<f64>().unwrap_or_default();
            let sepal_width = record.get(1).expect("expected a number")
                .parse::<f64>().unwrap_or_default();
            let petal_length = record.get(2).expect("expected a number")
                .parse::<f64>().unwrap_or_default();
            let petal_width = record.get(3).expect("expected a number")
                .parse::<f64>().unwrap_or_default();
            Some(IrisInstance {
                sepal_length,
                sepal_width,
                petal_length,
                petal_width,
                label: String::from(record.get(4).expect("expected a string")),
            })
        } else {
            None
        }
    }
}

fn read_iris_data(filename: &str) -> Vec<IrisInstance> {
    csv::ReaderBuilder::new()
        .has_headers(false)
        .from_path(filename)
        .expect("Could not read file")
        .records()
        .flatten() // remove possible record error
        .flat_map(|record| IrisInstance::from_csv_record(&record))
        .collect::<Vec<IrisInstance>>()
}

fn print_statistics<F>(instances: &[IrisInstance], attribute_name: &str, attribute_value: F)
    where F: Fn(&IrisInstance) -> f64,
{
    let values: Vec<f64> = instances.iter().map(attribute_value).collect();

    println!("{}", attribute_name);
    println!(" -- Minimum: {:.2}",
        values.iter().min_by(|a, b| a.total_cmp(b)).unwrap_or(&0.0));
    println!(" -- Maximum: {:.2}",
        values.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&0.0));
    println!(" -- Mean: {:.2}", statistical::mean(&values));
    println!(" -- Stddev: {:.2}", statistical::standard_deviation(&values, None));
}

fn main() {
    let instances = read_iris_data("iris.data");
    println!("Read {} instances.", instances.len());

    print_statistics(&instances, "Sepal Length", |instance| instance.sepal_length);
    print_statistics(&instances, "Sepal Width", |instance| instance.sepal_width);
    print_statistics(&instances, "Petal Length", |instance| instance.petal_length);
    print_statistics(&instances, "Petal Width", |instance| instance.petal_width);

    // clustering algorithm requires samples as vectors
    let samples: Vec<Vec<f64>> = instances.iter()
        .map(|instance| vec![instance.sepal_length, instance.sepal_width,
                             instance.petal_length, instance.petal_width])
        .collect();
    let clustering = clustering::kmeans(3, &samples, 100);

    for i in 0..3 {
        println!("Centre: {:?}", clustering.centroids[i]);
        println!("Cluster has: {} points",
            clustering.membership.iter().filter(|&n| *n == i).count());
    }
}