paulking
June 05, 2024

Classifying Iris flowers with Groovy, Deep Learning, and GraalVM

This presentation looks at using Groovy to classify Iris flowers with standard classification algorithms and neural networks. The deep learning examples are also compiled with GraalVM.


Transcript

  1. Classifying Iris flowers with Groovy, Deep Learning, and GraalVM

    Dr Paul King, VP Apache Groovy & Distinguished Engineer, Object Computing
    Twitter/X | Mastodon: @ApacheGroovy | @…
    Apache Groovy: https://groovy.apache.org/ https://groovy-lang.org/
    Repo: https://github.com/paulk-asert/groovy-data-science
    Slides: https://speakerdeck.com/paulk/groovy-iris
  2. Why use Groovy in 2024?

    It’s like a super version of Java:
    • Simpler scripting: more powerful yet more concise
    • Extension methods: 2000+ enhancements to Java classes for a great out-of-the-box experience (batteries included)
    • Flexible Typing: from dynamic duck-typing (terse code) to extensible stronger-than-Java static typing (better checking)
    • Improved OO & Functional Features: from traits (more powerful and flexible OO designs) to tail recursion and memoizing/partial application of pure functions
    • AST transforms: tens of lines instead of hundreds or thousands
    • Java Features Earlier: recent Java features on older JDKs
  3. Hello world script

    Java (JDK23 with preview enabled):
    import java.util.List;
    import java.util.function.Predicate;

    void main() {
        var pets = List.of("bird", "cat", "dog");
        Predicate<String> stringSize3 = s -> s.length() == 3;
        System.out.println(pets.stream().filter(stringSize3).toList());
    }

    Groovy (JDK8+):
    var pets = ["bird", "cat", "dog"]
    var size3 = { it.size() == 3 }
    println pets.findAll(size3)

    Output: [cat, dog]
  6. Hello world test

    Java (JUnit 4):
    import org.junit.Test;
    import java.util.List;
    import java.util.function.Predicate;
    import static org.junit.Assert.assertEquals;

    public class HelloTest {
        @Test
        public void testSize() {
            var pets = List.of("bird", "cat", "dog");
            var nums = List.of(List.of(1, 2), List.of(3, 4, 5), List.of(6));
            Predicate<String> stringSize3 = s -> s.length() == 3;
            Predicate<List<Integer>> listSize3 = s -> s.size() == 3;
            // JUnit's assertEquals takes the expected value first
            assertEquals(2L, pets.stream().filter(stringSize3).count());
            assertEquals(1L, nums.stream().filter(listSize3).count());
        }
    }

    Groovy:
    var pets = ["bird", "cat", "dog"]
    var nums = [1..2, 3..5, 6..6]
    var size3 = { it.size() == 3 }
    assert pets.count(size3) == 2 && nums.count(size3) == 1
  7. Operator Overloading

    jshell> import org.apache.commons.math3.linear.MatrixUtils

    jshell> double[][] d1 = { {10d, 0d}, {0d, 10d}}
    d1 ==> double[2][] { double[2] { 10.0, 0.0 }, double[2] { 0.0, 10.0 } }

    jshell> var m1 = MatrixUtils.createRealMatrix(d1)
    m1 ==> Array2DRowRealMatrix{{10.0,0.0},{0.0,10.0}}

    jshell> double[][] d2 = { {-1d, 1d}, {1d, -1d}}
    d2 ==> double[2][] { double[2] { -1.0, 1.0 }, double[2] { 1.0, -1.0 } }

    jshell> var m2 = MatrixUtils.createRealMatrix(d2)
    m2 ==> Array2DRowRealMatrix{{-1.0,1.0},{1.0,-1.0}}

    jshell> System.out.println(m1.multiply(m2.power(2)))
    Array2DRowRealMatrix{{20.0,-20.0},{-20.0,20.0}}
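What `multiply` and `power` compute can be checked without Commons Math. The following is a minimal, dependency-free Java sketch on plain `double[][]` arrays; the class and method names are our own, and `power` here is naive repeated multiplication, not necessarily how Commons Math implements it. It reproduces the jshell result above.

```java
import java.util.Arrays;

class MatrixOps {
    // Row-by-column product: result[i][j] = sum over k of a[i][k] * b[k][j]
    static double[][] multiply(double[][] a, double[][] b) {
        int rows = a.length, cols = b[0].length, inner = b.length;
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++)
            for (int j = 0; j < cols; j++)
                for (int k = 0; k < inner; k++)
                    result[i][j] += a[i][k] * b[k][j];
        return result;
    }

    // Naive power of a square matrix: start from the identity and multiply p times
    static double[][] power(double[][] a, int p) {
        int n = a.length;
        double[][] result = new double[n][n];
        for (int i = 0; i < n; i++) result[i][i] = 1.0;
        for (int i = 0; i < p; i++) result = multiply(result, a);
        return result;
    }

    public static void main(String[] args) {
        double[][] d1 = {{10d, 0d}, {0d, 10d}};
        double[][] d2 = {{-1d, 1d}, {1d, -1d}};
        // Same expression as the jshell session: m1.multiply(m2.power(2))
        System.out.println(Arrays.deepToString(multiply(d1, power(d2, 2))));
        // prints [[20.0, -20.0], [-20.0, 20.0]]
    }
}
```

In Groovy, `a * b` calls `a.multiply(b)` and `a ** b` calls `a.power(b)`, so a class with methods of those names, such as Commons Math's `RealMatrix`, can be used with operator syntax.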
  8. Operator Overloading (plus other features)
  9. Primitive array extension methods

    int[] nums = {10, 20, 15, 30, 5};

    Groovy, using primitive array extension methods:
    IntComparator maxAbs = (i, j) -> i.abs() <=> j.abs()
    nums.max()
    nums.max(maxAbs)

    Groovy, using streams:
    Comparator<Integer> maxAbs = Comparator.<Integer>comparingInt(Math::abs)
    nums.intStream().max().getAsInt()
    nums.stream().max(maxAbs).get()

    Java:
    import java.util.Arrays;
    import java.util.Comparator;

    public class JavaStreamsMax {
        private static Comparator<Integer> comparator = Comparator.comparingInt(Math::abs);

        public static int max(int[] nums) {
            return Arrays.stream(nums).max().getAsInt();
        }

        public static int maxAbs(int[] nums) {
            return Arrays.stream(nums).boxed().max(comparator).get();
        }
    }
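With the slide's all-positive array, `max` and `maxAbs` both return 30; the difference only shows once a negative value is present. A small Java sketch reusing the slide's stream approach, with one value changed to -40 for illustration (the class name and sample values are our own):

```java
import java.util.Arrays;
import java.util.Comparator;

class MaxAbsDemo {
    // Orders boxed Integers by absolute value (same comparator as JavaStreamsMax)
    private static final Comparator<Integer> BY_ABS = Comparator.comparingInt(Math::abs);

    // Largest value; throws NoSuchElementException for an empty array
    static int max(int[] nums) {
        return Arrays.stream(nums).max().getAsInt();
    }

    // Value with the largest magnitude, returned with its original sign
    static int maxAbs(int[] nums) {
        return Arrays.stream(nums).boxed().max(BY_ABS).get();
    }

    public static void main(String[] args) {
        int[] nums = {10, -40, 15, 30, 5};   // slide's array with 20 replaced by -40
        System.out.println(max(nums));       // 30
        System.out.println(maxAbs(nums));    // -40
    }
}
```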
  10. Better: primitive array extension methods
  11. AST Transformations

    // imports not shown
    public class Book {

    private String $to$string; private int $hash$code; private final List<String> authors; private final String title; private final Date publicationDate; private static final java.util.Comparator this$TitleComparator; private static final java.util.Comparator this$PublicationDateComparator; public Book(List<String> authors, String title, Date publicationDate) { if (authors == null) { this.authors = null; } else { if (authors instanceof Cloneable) { List<String> authorsCopy = (List<String>) ((ArrayList<?>) authors).clone(); this.authors = (List<String>) (authorsCopy instanceof SortedSet ? DefaultGroovyMethods.asImmutable(authorsCopy) : authorsCopy instanceof SortedMap ? DefaultGroovyMethods.asImmutable(authorsCopy) : authorsCopy instanceof Set ? DefaultGroovyMethods.asImmutable(authorsCopy) : authorsCopy instanceof Map ? DefaultGroovyMethods.asImmutable(authorsCopy) : authorsCopy instanceof List ? DefaultGroovyMethods.asImmutable(authorsCopy) : DefaultGroovyMethods.asImmutable(authorsCopy)); } else { this.authors = (List<String>) (authors instanceof SortedSet ? DefaultGroovyMethods.asImmutable(authors) : authors instanceof SortedMap ? DefaultGroovyMethods.asImmutable(authors) : authors instanceof Set ? DefaultGroovyMethods.asImmutable(authors) : authors instanceof Map ? DefaultGroovyMethods.asImmutable(authors) : authors instanceof List ? 
DefaultGroovyMethods.asImmutable(authors) : DefaultGroovyMethods.asImmutable(authors)); } } this.title = title; if (publicationDate == null) { this.publicationDate = null; } else { this.publicationDate = (Date) publicationDate.clone(); } } public Book(Map args) { if ( args == null) { args = new HashMap(); } ImmutableASTTransformation.checkPropNames(this, args); if (args.containsKey("authors")) { if ( args.get("authors") == null) { this .authors = null; } else { if (args.get("authors") instanceof Cloneable) { List<String> authorsCopy = (List<String>) ((ArrayList<?>) args.get("authors")).clone(); this.authors = (List<String>) (authorsCopy instanceof SortedSet ? DefaultGroovyMethods.asImmutable(authorsCopy) : authorsCopy instanceof SortedMap ? DefaultGroovyMethods.asImmutable(authorsCopy) : authorsCopy instanceof Set ? DefaultGroovyMethods.asImmutable(authorsCopy) : authorsCopy instanceof Map ? DefaultGroovyMethods.asImmutable(authorsCopy) : authorsCopy instanceof List ? DefaultGroovyMethods.asImmutable(authorsCopy) : DefaultGroovyMethods.asImmutable(authorsCopy)); } else { List<String> authors = (List<String>) args.get("authors"); this.authors = (List<String>) (authors instanceof SortedSet ? DefaultGroovyMethods.asImmutable(authors) : authors instanceof SortedMap ? DefaultGroovyMethods.asImmutable(authors) : authors instanceof Set ? DefaultGroovyMethods.asImmutable(authors) : authors instanceof Map ? DefaultGroovyMethods.asImmutable(authors) : authors instanceof List ? 
DefaultGroovyMethods.asImmutable(authors) : DefaultGroovyMethods.asImmutable(authors)); } } } else { this .authors = null; } if (args.containsKey("title")) {this .title = (String) args.get("title"); } else { this .title = null;} if (args.containsKey("publicationDate")) { if (args.get("publicationDate") == null) { this.publicationDate = null; } else { this.publicationDate = (Date) ((Date) args.get("publicationDate")).clone(); } } else {this.publicationDate = null; } } … public Book() { this (new HashMap()); } public int compareTo(Book other) { if (this == other) { return 0; } Integer value = 0 value = this .title <=> other .title if ( value != 0) { return value } value = this .publicationDate <=> other .publicationDate if ( value != 0) { return value } return 0 } public static Comparator comparatorByTitle() { return this$TitleComparator; } public static Comparator comparatorByPublicationDate() { return this$PublicationDateComparator; } public String toString() { StringBuilder _result = new StringBuilder(); boolean $toStringFirst= true; _result.append("Book("); if ($toStringFirst) { $toStringFirst = false; } else { _result.append(", "); } _result.append(InvokerHelper.toString(this.getAuthors())); if ($toStringFirst) { $toStringFirst = false; } else { _result.append(", "); } _result.append(InvokerHelper.toString(this.getTitle())); if ($toStringFirst) { $toStringFirst = false; } else { _result.append(", "); } _result.append(InvokerHelper.toString(this.getPublicationDate())); _result.append(")"); if ($to$string == null) { $to$string = _result.toString(); } return $to$string; } public int hashCode() { if ( $hash$code == 0) { int _result = HashCodeHelper.initHash(); if (!(this.getAuthors().equals(this))) { _result = HashCodeHelper.updateHash(_result, this.getAuthors()); } if (!(this.getTitle().equals(this))) { _result = HashCodeHelper.updateHash(_result, this.getTitle()); } if (!(this.getPublicationDate().equals(this))) { _result = HashCodeHelper.updateHash(_result, 
this.getPublicationDate()); } $hash$code = (int) _result; } return $hash$code; } public boolean canEqual(Object other) { return other instanceof Book; } … public boolean equals(Object other) { if ( other == null) { return false; } if (this == other) { return true; } if (!( other instanceof Book)) { return false; } Book otherTyped = (Book) other; if (!(otherTyped.canEqual( this ))) { return false; } if (!(this.getAuthors() == otherTyped.getAuthors())) { return false; } if (!(this.getTitle().equals(otherTyped.getTitle()))) { return false; } if (!(this.getPublicationDate().equals(otherTyped.getPublicationDate()))) { return false; } return true; } public final Book copyWith(Map map) { if (map == null || map.size() == 0) { return this; } Boolean dirty = false; HashMap construct = new HashMap(); if (map.containsKey("authors")) { Object newValue = map.get("authors"); Object oldValue = this.getAuthors(); if (newValue != oldValue) { oldValue = newValue; dirty = true; } construct.put("authors", oldValue); } else { construct.put("authors", this.getAuthors()); } if (map.containsKey("title")) { Object newValue = map.get("title"); Object oldValue = this.getTitle(); if (newValue != oldValue) { oldValue = newValue; dirty = true; } construct.put("title", oldValue); } else { construct.put("title", this.getTitle()); } if (map.containsKey("publicationDate")) { Object newValue = map.get("publicationDate"); Object oldValue = this.getPublicationDate(); if (newValue != oldValue) { oldValue = newValue; dirty = true; } construct.put("publicationDate", oldValue); } else { construct.put("publicationDate", this.getPublicationDate()); } return dirty == true ? 
new Book(construct) : this; } public void writeExternal(ObjectOutput out) throws IOException { out.writeObject(authors); out.writeObject(title); out.writeObject(publicationDate); } public void readExternal(ObjectInput oin) throws IOException, ClassNotFoundException { authors = (List) oin.readObject(); title = (String) oin.readObject(); publicationDate = (Date) oin.readObject(); } … static { this$TitleComparator = new Book$TitleComparator(); this$PublicationDateComparator = new Book$PublicationDateComparator(); } public String getAuthors(int index) { return authors.get(index); } public List<String> getAuthors() { return authors; } public final String getTitle() { return title; } public final Date getPublicationDate() { if (publicationDate == null) { return publicationDate; } else { return (Date) publicationDate.clone(); } } public int compare(java.lang.Object param0, java.lang.Object param1) { return -1; } private static class Book$TitleComparator extends AbstractComparator<Book> { public Book$TitleComparator() { } public int compare(Book arg0, Book arg1) { if (arg0 == arg1) { return 0; } if (arg0 != null && arg1 == null) { return -1; } if (arg0 == null && arg1 != null) { return 1; } return arg0.title <=> arg1.title; } public int compare(java.lang.Object param0, java.lang.Object param1) { return -1; } } private static class Book$PublicationDateComparator extends AbstractComparator<Book> { public Book$PublicationDateComparator() { } public int compare(Book arg0, Book arg1) { if ( arg0 == arg1 ) { return 0; } if ( arg0 != null && arg1 == null) { return -1; } if ( arg0 == null && arg1 != null) { return 1; } return arg0 .publicationDate <=> arg1 .publicationDate; } public int compare(java.lang.Object param0, java.lang.Object param1) { return -1; } } }

    Groovy:
    @Immutable(copyWith = true)
    @Sortable(excludes = 'authors')
    @AutoExternalize
    class Book {
        @IndexedProperty List<String> authors
        String title
        Date publicationDate
    }

    10 lines of Groovy or 600 lines of Java
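For comparison, a hand-written immutable class can be shorter in modern Java (JDK 16+) by using a record. This is a minimal sketch assuming only immutability with defensive copies, value-based `equals`/`hashCode`/`toString`, and ordering by title then publication date (nulls first) are wanted; it does not replicate `copyWith`, the `Map`-based constructor, `getAuthors(int)` or `Externalizable` support, and its null handling and cloning choices are our own.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.List;

// Title first, then publication date; authors are excluded, like @Sortable(excludes = 'authors')
record Book(List<String> authors, String title, Date publicationDate) implements Comparable<Book> {
    private static final Comparator<Book> ORDER =
            Comparator.comparing(Book::title, Comparator.nullsFirst(Comparator.<String>naturalOrder()))
                      .thenComparing(Book::publicationDate, Comparator.nullsFirst(Comparator.<Date>naturalOrder()));

    Book {
        // Defensive copies: List.copyOf is unmodifiable (and rejects null elements); Date is mutable, so clone it
        authors = authors == null ? null : List.copyOf(authors);
        publicationDate = publicationDate == null ? null : (Date) publicationDate.clone();
    }

    // Return a copy so callers cannot mutate the stored Date
    @Override
    public Date publicationDate() {
        return publicationDate == null ? null : (Date) publicationDate.clone();
    }

    @Override
    public int compareTo(Book other) {
        return ORDER.compare(this, other);
    }

    public static void main(String[] args) {
        List<Book> books = new ArrayList<>(List.of(
                new Book(List.of("Author One"), "Zebra Tales", new Date(0L)),
                new Book(List.of("Author Two"), "Apple Stories", new Date(0L))));
        Collections.sort(books);
        books.forEach(b -> System.out.println(b.title()));   // Apple Stories, then Zebra Tales
    }
}
```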
  12. Classification Overview

    Classification:
    • Predicting the class of some data
    Algorithms:
    • Logistic Regression, Naïve Bayes, Stochastic Gradient Descent, K-Nearest Neighbors, Decision Tree, Random Forest, Support Vector Machine
    Aspects:
    • Over/underfitting
    • Ensemble
    • Confusion matrix
    Applications:
    • Image/speech recognition
    • Spam filtering
    • Medical diagnosis
    • Fraud detection
    • Customer behaviour prediction
  14. Case Study: classification of Iris flowers

    British statistician & biologist Ronald Fisher's 1936 paper “The use of multiple measurements in taxonomic problems”, as an example of linear discriminant analysis
    150 samples, 50 each of three species of Iris:
    • setosa
    • versicolor
    • virginica
    Four features measured for each sample:
    • sepal length
    • sepal width
    • petal length
    • petal width
    https://en.wikipedia.org/wiki/Iris_flower_data_set
    https://archive.ics.uci.edu/ml/datasets/Iris
    [Images: setosa, versicolor and virginica flowers, with sepal and petal labelled]
  16. Iris flower data – Jupyter

    Sepal length,Sepal width,Petal length,Petal width,Class
    5.1,3.5,1.4,0.2,Iris-setosa
    4.9,3.0,1.4,0.2,Iris-setosa
    4.7,3.2,1.3,0.2,Iris-setosa
    ...
    7.0,3.2,4.7,1.4,Iris-versicolor
    6.4,3.2,4.5,1.5,Iris-versicolor
    ...
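Each row holds four measurements (in centimetres) followed by the class label. A minimal Java sketch of parsing such rows and counting samples per species, fed with the sample rows shown above; the `Iris` record and `countBySpecies` method are illustrative names, and reading the real file (for example with `Files.readAllLines`) is left out. On the full file each species would count 50.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class IrisCsv {
    // One row of iris_data.csv: four measurements plus the species label
    record Iris(double sepalLength, double sepalWidth, double petalLength, double petalWidth, String species) {
        static Iris parse(String line) {
            String[] f = line.split(",");
            return new Iris(Double.parseDouble(f[0]), Double.parseDouble(f[1]),
                    Double.parseDouble(f[2]), Double.parseDouble(f[3]), f[4].trim());
        }
    }

    // Skips the header row and counts samples per species (sorted by name)
    static Map<String, Integer> countBySpecies(List<String> lines) {
        Map<String, Integer> counts = new TreeMap<>();
        for (String line : lines.subList(1, lines.size())) {
            if (line.isBlank()) continue;   // tolerate a trailing empty line
            counts.merge(Iris.parse(line).species(), 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        List<String> lines = List.of(
                "Sepal length,Sepal width,Petal length,Petal width,Class",
                "5.1,3.5,1.4,0.2,Iris-setosa",
                "4.9,3.0,1.4,0.2,Iris-setosa",
                "4.7,3.2,1.3,0.2,Iris-setosa",
                "7.0,3.2,4.7,1.4,Iris-versicolor",
                "6.4,3.2,4.5,1.5,Iris-versicolor");
        System.out.println(countBySpecies(lines));   // {Iris-setosa=3, Iris-versicolor=2}
    }
}
```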
  19. Iris flower data – Weka Naïve Bayes

    import weka.classifiers.bayes.NaiveBayes
    import weka.core.converters.CSVLoader
    import org.knowm.xchart.SwingWrapper
    import org.knowm.xchart.XYChartBuilder
    import static org.knowm.xchart.XYSeries.XYSeriesRenderStyle.Scatter

    def file = getClass().classLoader.getResource('iris_data.csv').file as File
    def species = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    def loader = new CSVLoader(file: file)
    def model = new NaiveBayes()
    def allInstances = loader.dataSet
    allInstances.classIndex = 4                  // the species column is the class
    model.buildClassifier(allInstances)
    double[] actual = allInstances*.value(4)
    double[] predicted = allInstances.collect(model::classifyInstance)
    double[] petalL = allInstances*.value(2)     // column 2 is petal length
    double[] petalW = allInstances*.value(3)     // column 3 is petal width
    def indices = actual.indices

    def chart = new XYChartBuilder().width(900).height(450).
        title("Species").xAxisTitle("Petal length").yAxisTitle("Petal width").build()
    species.eachWithIndex{ String name, int i ->
        // instances predicted as class i, split by whether they really are class i
        def groups = indices.findAll{ predicted[it] == i }.groupBy{ actual[it] == i }
        Collection found = groups[true] ?: []
        Collection errors = groups[false] ?: []
        println "$name: ${found.size()} correct, ${errors.size()} incorrect"
        chart.addSeries("$name correct", petalL[found], petalW[found]).with { XYSeriesRenderStyle = Scatter }
        if (errors) {
            chart.addSeries("$name incorrect", petalL[errors], petalW[errors]).with { XYSeriesRenderStyle = Scatter }
        }
    }
    new SwingWrapper(chart).displayChart()

    Output:
    Iris-setosa: 50 correct, 0 incorrect
    Iris-versicolor: 48 correct, 4 incorrect
    Iris-virginica: 46 correct, 2 incorrect
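Note how the counting works: an instance is tallied under the class it was *predicted* as, so each "incorrect" figure counts flowers wrongly assigned to that species (for example, 4 non-versicolor flowers predicted as Iris-versicolor), and the six figures add up to all 150 instances. A small Java sketch of the same tallying with made-up label arrays; Weka's `classifyInstance` returns the class index as a `double`, which is why the arrays are `double[]`.

```java
class PerClassCounts {
    // For each class i: among instances predicted as i, how many really are i (correct)
    // and how many belong to another class (incorrect). Mirrors the Groovy findAll/groupBy logic.
    static int[][] countsByPredictedClass(double[] actual, double[] predicted, int numClasses) {
        int[][] counts = new int[numClasses][2];   // [i][0] = correct, [i][1] = incorrect
        for (int idx = 0; idx < actual.length; idx++) {
            int p = (int) predicted[idx];
            if (actual[idx] == predicted[idx]) counts[p][0]++;
            else counts[p][1]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        String[] species = {"Iris-setosa", "Iris-versicolor", "Iris-virginica"};
        // Made-up example: one virginica flower predicted as versicolor
        double[] actual    = {0, 0, 1, 1, 2, 2};
        double[] predicted = {0, 0, 1, 1, 1, 2};
        int[][] counts = countsByPredictedClass(actual, predicted, species.length);
        for (int i = 0; i < species.length; i++)
            System.out.printf("%s: %d correct, %d incorrect%n", species[i], counts[i][0], counts[i][1]);
    }
}
```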
  20. Iris flower data – Weka Logistic Regression

    import weka.classifiers.functions.SimpleLogistic
    import weka.core.converters.CSVLoader
    import org.knowm.xchart.SwingWrapper
    import org.knowm.xchart.XYChartBuilder
    import static org.knowm.xchart.XYSeries.XYSeriesRenderStyle.Scatter

    def file = getClass().classLoader.getResource('iris_data.csv').file as File
    def species = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    def loader = new CSVLoader(file: file)
    def model = new SimpleLogistic()
    def allInstances = loader.dataSet
    allInstances.classIndex = 4                  // the species column is the class
    model.buildClassifier(allInstances)
    double[] actual = allInstances*.value(4)
    double[] predicted = allInstances.collect(model::classifyInstance)
    double[] petalL = allInstances*.value(2)     // column 2 is petal length
    double[] petalW = allInstances*.value(3)     // column 3 is petal width
    def indices = actual.indices

    def chart = new XYChartBuilder().width(900).height(450).
        title("Species").xAxisTitle("Petal length").yAxisTitle("Petal width").build()
    species.eachWithIndex{ String name, int i ->
        def groups = indices.findAll{ predicted[it] == i }.groupBy{ actual[it] == i }
        Collection found = groups[true] ?: []
        Collection errors = groups[false] ?: []
        println "$name: ${found.size()} correct, ${errors.size()} incorrect"
        chart.addSeries("$name correct", petalL[found], petalW[found]).with { XYSeriesRenderStyle = Scatter }
        if (errors) {
            chart.addSeries("$name incorrect", petalL[errors], petalW[errors]).with { XYSeriesRenderStyle = Scatter }
        }
    }
    new SwingWrapper(chart).displayChart()

    Output:
    Iris-setosa: 50 correct, 0 incorrect
    Iris-versicolor: 49 correct, 1 incorrect
    Iris-virginica: 49 correct, 1 incorrect
  21. Iris flower data – Weka J48 Decision Tree

    import weka.classifiers.trees.J48
    import weka.core.converters.CSVLoader
    import org.knowm.xchart.SwingWrapper
    import org.knowm.xchart.XYChartBuilder
    import static org.knowm.xchart.XYSeries.XYSeriesRenderStyle.Scatter

    def file = getClass().classLoader.getResource('iris_data.csv').file as File
    def species = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    def loader = new CSVLoader(file: file)
    def model = new J48()
    def allInstances = loader.dataSet
    allInstances.classIndex = 4                             // the species column is the class
    model.buildClassifier(allInstances)
    double[] actual = allInstances.collect{ it.value(4) }
    double[] predicted = allInstances.collect{ model.classifyInstance(it) }
    double[] petalL = allInstances.collect{ it.value(2) }   // column 2 is petal length
    double[] petalW = allInstances.collect{ it.value(3) }   // column 3 is petal width
    def indices = actual.indices

    def chart = new XYChartBuilder().width(900).height(450).
        title("Species").xAxisTitle("Petal length").yAxisTitle("Petal width").build()
    species.eachWithIndex{ String name, int i ->
        def groups = indices.findAll{ predicted[it] == i }.groupBy{ actual[it] == i }
        Collection found = groups[true] ?: []
        Collection errors = groups[false] ?: []
        println "$name: ${found.size()} correct, ${errors.size()} incorrect"
        chart.addSeries("$name correct", petalL[found], petalW[found]).with { XYSeriesRenderStyle = Scatter }
        if (errors) {
            chart.addSeries("$name incorrect", petalL[errors], petalW[errors]).with { XYSeriesRenderStyle = Scatter }
        }
    }
    new SwingWrapper(chart).displayChart()

    Output:
    Petal width <= 0.6: Iris-setosa (50.0)
    Petal width > 0.6
    |   Petal width <= 1.7
    |   |   Petal length <= 4.9: Iris-versicolor (48.0/1.0)
    |   |   Petal length > 4.9
    |   |   |   Petal width <= 1.5: Iris-virginica (3.0)
    |   |   |   Petal width > 1.5: Iris-versicolor (3.0/1.0)
    |   Petal width > 1.7: Iris-virginica (46.0/1.0)

    Number of Leaves : 5
    Size of the tree : 9

    Iris-setosa: 50 correct, 0 incorrect
    Iris-versicolor: 49 correct, 2 incorrect
    Iris-virginica: 48 correct, 1 incorrect
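In the tree printout, each leaf shows (instances reaching the leaf/instances misclassified there), so `(48.0/1.0)` is 48 flowers of which 1 is wrong. Because the tree only tests petal measurements, it can be transcribed by hand into plain code. A Java sketch of that transcription (hand-copied from the printout above, not generated by Weka; the class name is our own):

```java
class IrisJ48Rules {
    // Nested ifs equivalent to the printed J48 tree; sepal measurements are never tested
    static String classify(double petalLength, double petalWidth) {
        if (petalWidth <= 0.6) return "Iris-setosa";
        if (petalWidth <= 1.7) {
            if (petalLength <= 4.9) return "Iris-versicolor";
            return petalWidth <= 1.5 ? "Iris-virginica" : "Iris-versicolor";
        }
        return "Iris-virginica";
    }

    public static void main(String[] args) {
        System.out.println(classify(1.4, 0.2));   // first setosa row of the CSV: Iris-setosa
        System.out.println(classify(4.7, 1.4));   // first versicolor row shown: Iris-versicolor
    }
}
```

The three versicolor/virginica leaves under `Petal width <= 1.7` are where all three training errors in the output above occur.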
  22. Iris flower data – Smile KNN

    def features = ['Sepal length', 'Sepal width', 'Petal length', 'Petal width']
    def species = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    def file = getClass().classLoader.getResource('iris_data.csv').file
    Table table = Table.read().csv(file)
    def helper = new TablesawUtil(file)
    …
    (0..<features.size()).each {
        println table.summarize(features[it], mean, min, max).by('Class')
    }

    def dataFrame = table.smile().toDataFrame()
    def featureCols = dataFrame.drop('Class').toArray()
    def classNames = dataFrame.column('Class').toStringArray()
    int[] classes = classNames.collect{ species.indexOf(it) }
    …
    …
    def knn = KNN.fit(featureCols, classes, 3)
    def predictions = knn.predict(featureCols)
    println """
    Confusion matrix:
    ${ConfusionMatrix.of(classes, predictions)}
    """

    table = table.addColumns(StringColumn.create('Result', predictions.indexed().collect{ idx, predictedClass ->
        def (actual, predicted) = [classNames[idx], species[predictedClass]]
        actual == predicted ? predicted : "$predicted/$actual".toString()
    }))

    def title = 'Petal width vs length with predicted[/actual] class'
    helper.show(ScatterPlot.create(title, table, 'Petal width', 'Petal length', 'Result'), 'KNNClassification')

    // use cross validation to get accuracy
    CrossValidation.classification(10, featureCols, classes, (x, y) -> KNN.fit(x, y, 3)).with {
        printf 'Accuracy: %.2f%% +/- %.2f\n', 100 * avg.accuracy, 100 * sd.accuracy
    }
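`CrossValidation.classification(10, …)` fits and evaluates 10 models, each tested on a different tenth of the data, and `avg`/`sd` summarize the per-fold metrics. A Java sketch of just that summary step, with made-up fold accuracies; it uses the sample standard deviation, and whether Smile uses the sample or population form is not shown on the slide.

```java
class CvSummary {
    // Arithmetic mean of the per-fold scores
    static double mean(double[] xs) {
        double sum = 0;
        for (double x : xs) sum += x;
        return sum / xs.length;
    }

    // Sample standard deviation (n - 1 denominator) of the per-fold scores
    static double sampleSd(double[] xs) {
        double m = mean(xs), sumSq = 0;
        for (double x : xs) sumSq += (x - m) * (x - m);
        return Math.sqrt(sumSq / (xs.length - 1));
    }

    public static void main(String[] args) {
        double[] foldAccuracy = {0.9, 1.0, 0.95};   // made-up per-fold accuracies
        // Same output format as the slide's printf
        System.out.printf("Accuracy: %.2f%% +/- %.2f%n", 100 * mean(foldAccuracy), 100 * sampleSd(foldAccuracy));
    }
}
```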
  23. KNN Classification

    • Calculate distance to other data points
    • Find closest K points
    • Regression: Average values
    • Classification: Label with majority class
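The steps above can be sketched from scratch in Java for classification. This is not Smile's `KNN` implementation: it uses Euclidean distance, a full sort instead of a spatial index, and breaks voting ties toward the lowest class index (our own choice). The training rows are the five sample rows from the CSV shown earlier; the query points are made up.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;

class SimpleKnn {
    // Euclidean distance between two feature vectors of equal length
    static double distance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // 1) distance to every training point, 2) take the k closest, 3) majority vote on their labels
    static int predict(double[][] x, int[] y, int numClasses, int k, double[] query) {
        Integer[] order = IntStream.range(0, x.length).boxed().toArray(Integer[]::new);
        Arrays.sort(order, Comparator.comparingDouble(i -> distance(x[i], query)));
        int[] votes = new int[numClasses];
        for (int n = 0; n < k && n < order.length; n++) votes[y[order[n]]]++;
        int best = 0;
        for (int c = 1; c < numClasses; c++) if (votes[c] > votes[best]) best = c;
        return best;
    }

    public static void main(String[] args) {
        double[][] x = {{5.1, 3.5, 1.4, 0.2}, {4.9, 3.0, 1.4, 0.2}, {4.7, 3.2, 1.3, 0.2},
                        {7.0, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}};
        int[] y = {0, 0, 0, 1, 1};   // 0 = setosa, 1 = versicolor
        System.out.println(predict(x, y, 2, 3, new double[]{5.0, 3.4, 1.5, 0.2}));   // 0
        System.out.println(predict(x, y, 2, 3, new double[]{6.7, 3.1, 4.4, 1.4}));   // 1
    }
}
```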
  26. Iris flower data – Smile KNN

    Feature stats by species:

    iris_data.csv summary
    Class           | Mean [Sepal length] | Min [Sepal length] | Max [Sepal length] |
    ----------------------------------------------------------------------------------
    Iris-setosa     |               5.006 |                4.3 |                5.8 |
    Iris-versicolor |               5.936 |                4.9 |                  7 |
    Iris-virginica  |               6.588 |                4.9 |                7.9 |

    iris_data.csv summary
    Class           | Mean [Sepal width] | Min [Sepal width] | Max [Sepal width] |
    -------------------------------------------------------------------------------
    Iris-setosa     |              3.418 |               2.3 |               4.4 |
    Iris-versicolor |               2.77 |                 2 |               3.4 |
    Iris-virginica  |              2.974 |               2.2 |               3.8 |

    iris_data.csv summary
    Class           | Mean [Petal length] | Min [Petal length] | Max [Petal length] |
    ----------------------------------------------------------------------------------
    Iris-setosa     |               1.464 |                  1 |                1.9 |
    Iris-versicolor |                4.26 |                  3 |                5.1 |
    Iris-virginica  |               5.552 |                4.5 |                6.9 |

    iris_data.csv summary
    Class           | Mean [Petal width] | Min [Petal width] | Max [Petal width] |
    -------------------------------------------------------------------------------
    Iris-setosa     |              0.244 |               0.1 |               0.6 |
    Iris-versicolor |              1.326 |                 1 |               1.8 |
    Iris-virginica  |              2.026 |               1.4 |               2.5 |

    Confusion matrix:
    ROW=truth and COL=predicted
    class  0 |      50 |       0 |       0 |
    class  1 |       0 |      47 |       3 |
    class  2 |       0 |       3 |      47 |
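The confusion matrix has the true class in rows and the predicted class in columns, so accuracy is the diagonal over the total: (50 + 47 + 47) / 150 = 0.96. Because `knn.predict` was applied to the same data the model was fitted on, this is training-set accuracy; the cross-validation call in slide 22 is what estimates accuracy on unseen data. A Java sketch computing accuracy plus per-class precision and recall from such a matrix (the class name is our own):

```java
class ConfusionStats {
    // Rows are the true class, columns the predicted class
    static double accuracy(int[][] m) {
        long correct = 0, total = 0;
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[i].length; j++) {
                total += m[i][j];
                if (i == j) correct += m[i][j];
            }
        }
        return (double) correct / total;
    }

    // Of everything predicted as class c (column c), the fraction that really is c
    static double precision(int[][] m, int c) {
        long predictedAsC = 0;
        for (int[] row : m) predictedAsC += row[c];
        return (double) m[c][c] / predictedAsC;
    }

    // Of everything that really is class c (row c), the fraction predicted as c
    static double recall(int[][] m, int c) {
        long actuallyC = 0;
        for (int v : m[c]) actuallyC += v;
        return (double) m[c][c] / actuallyC;
    }

    public static void main(String[] args) {
        int[][] knn = {{50, 0, 0}, {0, 47, 3}, {0, 3, 47}};   // matrix from the slide
        System.out.printf("accuracy = %.2f%n", accuracy(knn));
        for (int c = 0; c < knn.length; c++)
            System.out.printf("class %d: precision = %.2f, recall = %.2f%n", c, precision(knn, c), recall(knn, c));
    }
}
```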
  27. Iris flower data – wekaDeeplearning4j

    import weka.classifiers.AbstractClassifier
    import weka.classifiers.Evaluation
    import weka.core.Instances
    import weka.core.Utils
    import weka.core.WekaPackageManager
    import weka.core.converters.CSVLoader

    WekaPackageManager.loadPackages(true)
    def file = getClass().classLoader.getResource('iris_data.csv').file as File
    def loader = new CSVLoader(file: file)
    def data = loader.dataSet
    data.classIndex = 4

    def options = Utils.splitOptions("-S 1 -numEpochs 10 -layer \"weka.dl4j.layers.OutputLayer -activation weka.dl4j.activations.ActivationSoftmax -lossFn weka.dl4j.lossfunctions.LossMCXENT\"")
    AbstractClassifier myClassifier = Utils.forName(AbstractClassifier, "weka.classifiers.functions.Dl4jMlpClassifier", options)

    // Stratify and split
    Random rand = new Random(0)
    Instances randData = new Instances(data)
    randData.randomize(rand)
    randData.stratify(3)
    Instances train = randData.trainCV(3, 0)
    Instances test = randData.testCV(3, 0)

    // Build the classifier on the training data
    myClassifier.buildClassifier(train)

    // Evaluate the model on test data
    Evaluation eval = new Evaluation(test)
    eval.evaluateModel(myClassifier, test)
    println eval.toSummaryString()
    println eval.toMatrixString()
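With `stratify(3)` followed by `trainCV(3, 0)` and `testCV(3, 0)`, fold 0 (a third of the 150 instances, hence the 50 test instances in the results) is held out for testing and the other two folds are used for training. A Java sketch of the idea behind stratified fold assignment; it is not Weka's exact algorithm, so the folds will not match Weka's for the same seed.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;

class StratifiedFolds {
    // Returns fold[i] in 0..k-1 for each instance so every fold gets a near-equal share of each class.
    // Sketch: shuffle with a seeded Random, stable-sort by class label, then deal out round-robin.
    static int[] assignFolds(int[] labels, int k, long seed) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < labels.length; i++) order.add(i);
        Collections.shuffle(order, new Random(seed));
        order.sort(Comparator.comparingInt(i -> labels[i]));   // stable: shuffled order kept within a class
        int[] fold = new int[labels.length];
        for (int pos = 0; pos < order.size(); pos++) fold[order.get(pos)] = pos % k;
        return fold;
    }

    public static void main(String[] args) {
        int[] labels = new int[150];
        for (int i = 0; i < labels.length; i++) labels[i] = i / 50;   // 50 instances of each of 3 classes
        int[] fold = assignFolds(labels, 3, 0L);
        int[][] counts = new int[3][3];                              // counts[fold][class]
        for (int i = 0; i < labels.length; i++) counts[fold[i]][labels[i]]++;
        for (int f = 0; f < 3; f++) System.out.println("fold " + f + ": " + Arrays.toString(counts[f]));
    }
}
```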
  28. Iris flower data – wekaDeeplearning4j

    Output:
    [main] INFO org.deeplearning4j.nn.graph.ComputationGraph - Starting ComputationGraph with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
    Training Dl4jMlpClassifier...: [] ETA: 00:00:00
    [INFO ] 00:03:31.035 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [1/10] took 00:00:00.670
    Training Dl4jMlpClassifier...: [====== ] ETA: 00:00:06
    [INFO ] 00:03:31.152 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [2/10] took 00:00:00.113
    Training Dl4jMlpClassifier...: [============ ] ETA: 00:00:03
    [INFO ] 00:03:31.244 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [3/10] took 00:00:00.090
    Training Dl4jMlpClassifier...: [================== ] ETA: 00:00:02
    [INFO ] 00:03:31.325 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [4/10] took 00:00:00.079
    Training Dl4jMlpClassifier...: [======================== ] ETA: 00:00:01
    [INFO ] 00:03:31.470 [main] weka.dl4j.listener.EpochListener - Epoch [5/10] Train Set: Loss: 0.510342
    [INFO ] 00:03:31.470 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [5/10] took 00:00:00.144
    Training Dl4jMlpClassifier...: [============================== ] ETA: 00:00:01
    [INFO ] 00:03:31.546 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [6/10] took 00:00:00.073
    Training Dl4jMlpClassifier...: [==================================== ] ETA: 00:00:00
    [INFO ] 00:03:31.611 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [7/10] took 00:00:00.063
    Training Dl4jMlpClassifier...: [========================================== ] ETA: 00:00:00
    [INFO ] 00:03:31.714 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [8/10] took 00:00:00.101
    Training Dl4jMlpClassifier...: [================================================ ] ETA: 00:00:00
    [INFO ] 00:03:31.790 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [9/10] took 00:00:00.074
    Training Dl4jMlpClassifier...: [====================================================== ] ETA: 00:00:00
    [INFO ] 00:03:31.882 [main] weka.dl4j.listener.EpochListener - Epoch [10/10] Train Set: Loss: 0.286469
    …
  29. Iris flower data – wekaDeeplearning4j

    Output (continued):
    …
    [INFO ] 00:03:31.883 [main] weka.classifiers.functions.Dl4jMlpClassifier - Epoch [10/10] took 00:00:00.091
    Training Dl4jMlpClassifier...: [============================================================] ETA: 00:00:00
    Done!

    Correctly Classified Instances          40               80      %
    Incorrectly Classified Instances        10               20      %
    Kappa statistic                          0.701
    Mean absolute error                      0.2542
    Root mean squared error                  0.3188
    Relative absolute error                 57.2154 %
    Root relative squared error             67.6336 %
    Total Number of Instances               50

    === Confusion Matrix ===

      a  b  c   <-- classified as
     17  0  0 |  a = Iris-setosa
      0  9  8 |  b = Iris-versicolor
      0  2 14 |  c = Iris-virginica

    BUILD SUCCESSFUL in 22s
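The Kappa statistic in Weka's summary is Cohen's kappa, which discounts the agreement expected by chance: kappa = (observed - chance) / (1 - chance), where chance agreement comes from the row (actual) and column (predicted) totals. A Java sketch recomputing it from the confusion matrix above: observed = 40/50 = 0.8, chance = (17·17 + 17·11 + 16·22)/50² = 0.3312, giving about 0.701, matching Weka's figure.

```java
class Kappa {
    // Cohen's kappa from a square confusion matrix (rows = actual, columns = predicted)
    static double kappa(int[][] m) {
        int n = m.length;
        double total = 0, diagonal = 0;
        double[] rowSum = new double[n], colSum = new double[n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                total += m[i][j];
                rowSum[i] += m[i][j];
                colSum[j] += m[i][j];
                if (i == j) diagonal += m[i][j];
            }
        }
        double observed = diagonal / total;
        double chance = 0;
        for (int i = 0; i < n; i++) chance += rowSum[i] * colSum[i];
        chance /= total * total;
        return (observed - chance) / (1 - chance);
    }

    public static void main(String[] args) {
        int[][] m = {{17, 0, 0}, {0, 9, 8}, {0, 2, 14}};   // Weka confusion matrix from the slide
        System.out.printf("%.3f%n", kappa(m));             // 0.701
    }
}
```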
  30. Iris flower data – Deep Netts

    String[] cols = ['Sepal length', 'Sepal width', 'Petal length', 'Petal width']
    String[] species = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    int numInputs = cols.size()
    int numOutputs = species.size()

    // Deep Netts readCsv assumes normalized data, so we roll our own
    var dataSet = new TabularDataSet(numInputs, numOutputs).tap{ columnNames = cols + species }
    var data = getClass().classLoader.getResource('iris_data.csv').readLines()*.split(',')
    data[1..-1].each {
        dataSet.add(new TabularDataSet.Item(it[0..3]*.toFloat() as float[], oneHotEncode(it[4], species)))
    }
    scaleMax(dataSet)
    def (train, test) = dataSet.split(0.7, 0.3)  // 70/30% split
    …
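The data-loading code calls two helpers, `oneHotEncode` and `scaleMax`, whose bodies are not shown. The Java sketch below is one plausible reading of what they do. It assumes that `scaleMax` divides each feature column by its largest absolute value. It also works on plain arrays rather than Deep Netts' `TabularDataSet`. The class name `IrisPrep` and both method bodies are assumptions, not the repository's code.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical stand-ins for the oneHotEncode and scaleMax helpers used on the slide. */
final class IrisPrep {
    // One-hot encoding: all zeros except a single 1 at the label's position in classes.
    static float[] oneHotEncode(String label, String[] classes) {
        int idx = Arrays.asList(classes).indexOf(label);
        if (idx < 0) throw new IllegalArgumentException("Unknown label: " + label);
        float[] out = new float[classes.length];
        out[idx] = 1f;
        return out;
    }

    // Max scaling (in place): divide each column by its largest absolute value, so every
    // feature ends up in [-1, 1] (in [0, 1] for the all-positive Iris measurements).
    static void scaleMax(List<float[]> rows) {
        if (rows.isEmpty()) return;
        int cols = rows.get(0).length;
        for (int c = 0; c < cols; c++) {
            float max = 0f;
            for (float[] r : rows) max = Math.max(max, Math.abs(r[c]));
            if (max == 0f) continue; // an all-zero column is left unchanged
            for (float[] r : rows) r[c] /= max;
        }
    }

    public static void main(String[] args) {
        String[] species = {"Iris-setosa", "Iris-versicolor", "Iris-virginica"};
        System.out.println(Arrays.toString(oneHotEncode("Iris-versicolor", species)));
        List<float[]> rows = List.of(new float[]{5.1f, 3.5f}, new float[]{7.0f, 3.2f});
        scaleMax(rows);
        rows.forEach(r -> System.out.println(Arrays.toString(r)));
    }
}
```

Dividing by the column maximum preserves the ratios between values. By contrast, min-max scaling would also shift each column's minimum to 0.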
  31. Iris flower data – Deep Netts

    …
    var neuralNet = FeedForwardNetwork.builder()
        .addInputLayer(numInputs)
        .addFullyConnectedLayer(5, ActivationType.TANH)
        .addOutputLayer(numOutputs, ActivationType.SOFTMAX)
        .lossFunction(LossType.CROSS_ENTROPY)
        .randomSeed(456)
        .build()

    neuralNet.trainer.with {
        maxError = 0.04f
        learningRate = 0.01f
        momentum = 0.9f
        optimizer = OptimizerType.MOMENTUM
    }

    neuralNet.train(train)

    new ClassifierEvaluator().with {
        println "CLASSIFIER EVALUATION METRICS\n${evaluate(neuralNet, test)}"
        println "CONFUSION MATRIX\n$confusionMatrix"
    }
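To make the network configuration concrete, here is a library-free Java sketch of what a 4-5-3 network of this shape computes. It uses a tanh hidden layer, a softmax output and a cross-entropy loss. It also includes a single momentum update using the slide's learning rate (0.01) and momentum (0.9). The update is the textbook momentum formulation, and it is an assumption that Deep Netts' `OptimizerType.MOMENTUM` applies exactly this rule. The class `TinyMlp` and its methods are hypothetical, not Deep Netts API.

```java
/** Hypothetical sketch of the forward pass, loss and momentum update for a 4-5-3 network. */
final class TinyMlp {
    // Fully connected layer: z[o] = b[o] + sum_i w[o][i] * x[i].
    static double[] dense(double[][] w, double[] b, double[] x) {
        double[] z = new double[b.length];
        for (int o = 0; o < b.length; o++) {
            z[o] = b[o];
            for (int i = 0; i < x.length; i++) z[o] += w[o][i] * x[i];
        }
        return z;
    }

    static double[] tanh(double[] z) {
        double[] a = new double[z.length];
        for (int i = 0; i < z.length; i++) a[i] = Math.tanh(z[i]);
        return a;
    }

    // Softmax, shifted by the largest logit for numerical stability; outputs sum to 1.
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0;
        double[] p = new double[z.length];
        for (int i = 0; i < z.length; i++) { p[i] = Math.exp(z[i] - max); sum += p[i]; }
        for (int i = 0; i < z.length; i++) p[i] /= sum;
        return p;
    }

    // Cross-entropy against a one-hot target: -sum_k t[k] * ln(p[k]).
    static double crossEntropy(double[] p, double[] target) {
        double loss = 0;
        for (int k = 0; k < p.length; k++) if (target[k] != 0) loss -= target[k] * Math.log(p[k]);
        return loss;
    }

    // Forward pass: input -> fully connected + tanh -> fully connected + softmax.
    static double[] forward(double[][] w1, double[] b1, double[][] w2, double[] b2, double[] x) {
        return softmax(dense(w2, b2, tanh(dense(w1, b1, x))));
    }

    // Classic momentum: v = mu * v - lr * grad, then w = w + v (w and v updated in place).
    static void momentumStep(double[] w, double[] v, double[] grad, double lr, double mu) {
        for (int i = 0; i < w.length; i++) {
            v[i] = mu * v[i] - lr * grad[i];
            w[i] += v[i];
        }
    }

    public static void main(String[] args) {
        // With all-zero weights every class gets probability 1/3, so the loss is ln 3.
        double[][] w1 = new double[5][4], w2 = new double[3][5];
        double[] p = forward(w1, new double[5], w2, new double[3], new double[]{5.1, 3.5, 1.4, 0.2});
        System.out.println(java.util.Arrays.toString(p) + " loss=" + crossEntropy(p, new double[]{1, 0, 0}));
    }
}
```

The sketch leaves out backpropagation (computing `grad`), which Deep Netts' trainer handles internally.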
  32. Iris flower data – Deep Netts

    …
    $ time groovy -cp "build/lib/*" Iris.groovy
    16:49:27.089 [main] INFO deepnetts.core.DeepNetts - ------------------------------------------------------------------
    16:49:27.091 [main] INFO deepnetts.core.DeepNetts - TRAINING NEURAL NETWORK
    16:49:27.091 [main] INFO deepnetts.core.DeepNetts - ------------------------------------------------------------------
    16:49:27.100 [main] INFO deepnetts.core.DeepNetts - Epoch:1, Time:6ms, TrainError:0.8584314, TrainErrorChange:0.858431
    16:49:27.103 [main] INFO deepnetts.core.DeepNetts - Epoch:2, Time:3ms, TrainError:0.52278274, TrainErrorChange:-0.3356
    ...
    16:49:27.911 [main] INFO deepnetts.core.DeepNetts - Epoch:3031, Time:0ms, TrainError:0.029988592, TrainErrorChange:-0.
    TRAINING COMPLETED
    16:49:27.911 [main] INFO deepnetts.core.DeepNetts - Total Training Time: 820ms
    16:49:27.911 [main] INFO deepnetts.core.DeepNetts - ------------------------------------------------------------------
    CLASSIFIER EVALUATION METRICS
    Accuracy: 0.95681506 (How often is classifier correct in total)
    Precision: 0.974359 (How often is classifier correct when it gives positive prediction)
    F1Score: 0.974359 (Harmonic average (balance) of precision and recall)
    Recall: 0.974359 (When it is actually positive class, how often does it give positive prediction)
    CONFUSION MATRIX
                       none     Iris-setosa Iris-versicolor  Iris-virginica
               none       0               0               0               0
        Iris-setosa       0              14               0               0
    Iris-versicolor       0               0              18               1
     Iris-virginica       0               0               0              12

    real    0m3.160s
    user    0m10.156s
    sys     0m0.483s
  33. Iris flower data – Deep Netts with GraalVM

    groovyc -cp "build/lib/*" --compile-static iris.groovy
    java -agentlib:native-image-agent=config-output-dir=conf/ -cp ".:build/lib/*" iris
  34. Iris flower data – Deep Netts with GraalVM

    native-image --report-unsupported-elements-at-runtime \
      --initialize-at-run-time=groovy.grape.GrapeIvy,deepnetts.net.weights.RandomWeights \
      --initialize-at-build-time \
      --no-fallback \
      -H:ConfigurationFileDirectories=conf/ \
      -cp ".:build/lib/*" \
      -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN \
      iris
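The `conf/` directory passed via `-H:ConfigurationFileDirectories` holds the metadata recorded on the JVM by the `-agentlib:native-image-agent` run. It is needed because `native-image` analyses reachability ahead of time under a closed-world assumption. Code that picks a class by name at run time is therefore invisible to that analysis unless the configuration registers the class. Below is a minimal Java illustration of such a call, using the hypothetical class `ReflectiveLoad`. The slides do not show which reflective accesses the Groovy runtime and Deep Netts actually perform.

```java
/** Hypothetical example of a reflective access that ahead-of-time analysis cannot see. */
final class ReflectiveLoad {
    // The class name arrives as a string at run time, so static reachability analysis
    // cannot know which class (or constructor) will be needed.
    static Object create(String className) {
        try {
            return Class.forName(className).getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException("Cannot instantiate " + className, e);
        }
    }

    public static void main(String[] args) {
        String name = args.length > 0 ? args[0] : "java.util.ArrayList";
        System.out.println(create(name).getClass().getName());
    }
}
```

In a native image built without configuration covering the requested constructor, a call like this may fail at run time. Running the program once under the tracing agent is one way to record such accesses.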
  35. Iris flower data – Deep Netts with GraalVM

    $ time ./iris
    CLASSIFIER EVALUATION METRICS
    Accuracy: 0.93460923 (How often is classifier correct in total)
    Precision: 0.96491224 (How often is classifier correct when it gives positive prediction)
    F1Score: 0.96491224 (Harmonic average (balance) of precision and recall)
    Recall: 0.96491224 (When it is actually positive class, how often does it give positive prediction)
    CONFUSION MATRIX
                       none     Iris-setosa Iris-versicolor  Iris-virginica
               none       0               0               0               0
        Iris-setosa       0              21               0               0
    Iris-versicolor       0               0              20               2
     Iris-virginica       0               0               0              17

    real    0m0.131s
    user    0m0.096s
    sys     0m0.029s
  36. Questions?