Core ML and Vision Tutorial: On-device training on iOS

This tutorial introduces you to Core ML and Vision, two cutting-edge iOS frameworks, and shows you how to fine-tune a model on the device. By Christine Abernathy.

Testing the Prediction

Next, in addCanvasForDrawing(), add the following right after assigning drawingView:

drawingView.delegate = self

This makes the view controller the drawing view delegate.

Build and run the app and select a photo. Draw on the canvas and verify that the drawing is cleared and the following is logged in the console:

Console log showing no prediction

That’s to be expected. You haven’t added a sticker shortcut yet.

Now, walk through the flow of adding a sticker shortcut. When you come back to the selected photo's view, draw the same shortcut:

Close-up of dandelion with quote

Oops, the sticker still isn’t added! You can check the console log for clues:

Sticker still not added

After a bit of head-scratching, you may notice that your model has no clue about the sticker you’ve added. Time to fix that.

Updating the Model

You update a model by creating an MLUpdateTask. The update task initializer requires the compiled model file, training data and a completion handler. Generally, you want to save your updated model to disk and reload it, so new predictions make use of the latest data.

You’ll start by preparing the training data based on the shortcut drawings.

Recall that you made model predictions by passing in an MLFeatureProvider input. Likewise, you can train a model by passing in an MLFeatureProvider input. You can make batch predictions, or train with many inputs, by passing in an MLBatchProvider containing multiple feature providers.

First, open DrawingDataStore.swift and replace the Foundation import with the following:

import CoreML

You need this to set up the Core ML training inputs.

Next, add the following method to the extension:

func prepareTrainingData() throws -> MLBatchProvider {
  // 1
  var featureProviders: [MLFeatureProvider] = []
  // 2
  let inputName = "drawing"
  let outputName = "label"
  // 3      
  for drawing in drawings {
    if let drawing = drawing {
      // 4
      let inputValue = drawing.featureValue
      // 5
      let outputValue = MLFeatureValue(string: emoji)
      // 6
      let dataPointFeatures: [String: MLFeatureValue] =
        [inputName: inputValue,
        outputName: outputValue]
      // 7
      if let provider =
        try? MLDictionaryFeatureProvider(
          dictionary: dataPointFeatures) {
        featureProviders.append(provider)
      }
    }
  }
  // 8
  return MLArrayBatchProvider(array: featureProviders)
}

Here’s a step-by-step breakdown of this code:

  1. Initialize an empty array of feature providers.
  2. Define the names for the model training inputs.
  3. Loop through the drawings in the data store.
  4. Wrap the drawing training input in a feature value; a sketch of this featureValue property follows the list.
  5. Wrap the emoji training input in a feature value.
  6. Create a data point for the training input. This is a dictionary of the training input names and feature values.
  7. Create a feature provider for the data point and append it to the feature providers array.
  8. Finally, create a batch provider from the array of feature providers.
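
Note that drawing.featureValue comes from the Drawing type you built earlier in this tutorial. If you skipped ahead, here’s a minimal sketch of how a drawing can be wrapped as an image feature value. The SketchDrawing type, the 28×28 size and the grayscale pixel format are illustrative assumptions, not the project’s exact implementation:

import CoreGraphics
import CoreML
import CoreVideo

// A sketch only: the real featureValue property lives in Drawing.swift from
// an earlier page of this tutorial. The type name, the 28x28 size and the
// grayscale pixel format are assumptions for illustration.
struct SketchDrawing {
  let image: CGImage

  var featureValue: MLFeatureValue {
    // Wrap the drawn image in an image feature value matching the model's
    // expected input. try! keeps the property non-optional, mirroring how
    // prepareTrainingData() uses it above.
    return try! MLFeatureValue(
      cgImage: image,
      pixelsWide: 28,
      pixelsHigh: 28,
      pixelFormatType: kCVPixelFormatType_OneComponent8,
      options: nil)
  }
}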

Now, open UpdatableModel.swift and add the following method to the end of the UpdatableDrawingClassifier extension:

static func updateModel(
  at url: URL,
  with trainingData: MLBatchProvider,
  completionHandler: @escaping (MLUpdateContext) -> Void
) {
  do {
    let updateTask = try MLUpdateTask(
      forModelAt: url,
      trainingData: trainingData,
      configuration: nil,
      completionHandler: completionHandler)
    updateTask.resume()
  } catch {
    print("Couldn't create an MLUpdateTask.")
  }
}

The code creates the update task with the compiled model URL and passes in a batch provider containing the training data. The call to resume() starts training; the completion handler is called when training finishes.
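
The completion handler is all this app needs, but MLUpdateTask can also report progress while training runs. Here’s a sketch of the progress-handler variant of the initializer; the choice of events and the logging are illustrative only, and nothing in the tutorial app depends on this:

import CoreML

// A sketch of the progress-reporting variant of MLUpdateTask. The tutorial
// app only uses the completion handler; this illustrates the wider API.
func updateModelReportingProgress(
  at url: URL,
  trainingData: MLBatchProvider,
  completionHandler: @escaping (MLUpdateContext) -> Void
) throws {
  let handlers = MLUpdateProgressHandlers(
    forEvents: [.trainingBegin, .epochEnd],
    progressHandler: { context in
      // The current loss is published under the standard loss metric key.
      let loss = context.metrics[.lossValue] as? Double ?? .nan
      print("Training event: \(context.event), loss: \(loss)")
    },
    completionHandler: completionHandler)

  let updateTask = try MLUpdateTask(
    forModelAt: url,
    trainingData: trainingData,
    configuration: nil,
    progressHandlers: handlers)
  updateTask.resume()
}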

Saving the Model

Now, add the following method to the private extension for UpdatableModel:

static func saveUpdatedModel(_ updateContext: MLUpdateContext) {
  // 1
  let updatedModel = updateContext.model
  let fileManager = FileManager.default
  do {
    // 2
    try fileManager.createDirectory(
        at: tempUpdatedModelURL,
        withIntermediateDirectories: true,
        attributes: nil)
    // 3
    try updatedModel.write(to: tempUpdatedModelURL)
    // 4
    _ = try fileManager.replaceItemAt(
      updatedModelURL,
      withItemAt: tempUpdatedModelURL)
    print("Updated model saved to:\n\t\(updatedModelURL)")
  } catch let error {
    print("Could not save updated model to the file system: \(error)")
    return
  }
}

This helper method does the work of saving the updated model. It takes in an MLUpdateContext, which holds useful information about the training session. The method does the following:

  1. First, it gets the updated model from the update context. This in-memory model is not the same as the original model you loaded; it contains the parameters revised by training.
  2. Then it creates an intermediate folder for saving the updated model.
  3. It writes the updated model to that temporary folder.
  4. Finally, it replaces the saved model folder’s contents with the temporary copy. Writing directly over an existing mlmodelc folder produces errors, so the workaround is to save to an intermediate folder and then swap the contents over. (The two URL properties used here are sketched after this list.)
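
For reference, updatedModelURL and tempUpdatedModelURL were defined earlier in the tutorial as the permanent and scratch locations for the compiled model. Here’s a sketch of what such definitions can look like; the Application Support location and the file names are assumptions, not the project’s exact values:

import Foundation

// A sketch of the two model locations used above. The real definitions live
// earlier in UpdatableModel.swift; the directory and file names here are
// assumptions for illustration.
let appSupportURL = FileManager.default.urls(
  for: .applicationSupportDirectory,
  in: .userDomainMask)[0]

// The live, compiled model: loaded for predictions, replaced after training.
let updatedModelURL = appSupportURL
  .appendingPathComponent("UpdatableDrawingClassifier.mlmodelc")

// The scratch location the update task writes to first; its contents are
// then swapped in with replaceItemAt(_:withItemAt:) to avoid overwrite errors.
let tempUpdatedModelURL = appSupportURL
  .appendingPathComponent("UpdatableDrawingClassifier_tmp.mlmodelc")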

Performing the Update

Add the following method to the public UpdatableModel extension:

static func updateWith(
  trainingData: MLBatchProvider,
  completionHandler: @escaping () -> Void
) {
  loadModel()
  UpdatableDrawingClassifier.updateModel(
    at: updatedModelURL,
    with: trainingData) { context in
      saveUpdatedModel(context)
      DispatchQueue.main.async { completionHandler() }
  }
}

The code loads the model into memory, then calls the update method you defined in the UpdatableDrawingClassifier extension. That update’s completion handler saves the newly trained model, then runs this method’s own completion handler on the main queue.
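
loadModel() itself was written earlier in this tutorial. Conceptually, it reads the compiled model back from updatedModelURL so that the next prediction uses the freshly trained parameters. Here’s a minimal sketch of that idea; the container type and its stored model property are assumptions for illustration:

import CoreML

// A sketch of the reload step only. The tutorial's real loadModel() was
// written earlier; this standalone container and its stored model property
// are assumptions for illustration.
enum SketchModelStore {
  static var model: MLModel?
  static let updatedModelURL = FileManager.default.urls(
    for: .applicationSupportDirectory,
    in: .userDomainMask)[0]
    .appendingPathComponent("UpdatableDrawingClassifier.mlmodelc")

  static func loadModel() {
    // Until a first update has been saved, there's nothing on disk to load.
    guard FileManager.default.fileExists(atPath: updatedModelURL.path) else {
      return
    }
    do {
      model = try MLModel(contentsOf: updatedModelURL)
    } catch {
      print("Could not load the updated model: \(error)")
    }
  }
}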

Now, open AddShortcutViewController.swift and replace the savePressed(_:) implementation with the following:

do {
  let trainingData = try drawingDataStore.prepareTrainingData()
  DispatchQueue.global(qos: .userInitiated).async {
    UpdatableModel.updateWith(trainingData: trainingData) {
      DispatchQueue.main.async {
        self.performSegue(
          withIdentifier: "AddShortcutUnwindSegue",
          sender: self)
      }
    }
  }
} catch {
  print("Error updating model", error)
}

Here you’ve put everything together for training. After preparing the training data, you dispatch the model update to a background queue. When the update finishes, the completion handler triggers the unwind segue to transition back to the main screen.

Build and run the app and go through the steps to create a shortcut.

Add a shortcut screen with heart eyes emoji, three gray rectangles with hearts drawn in them

Verify that when you tap Save the console logs the model update:

Model update logged

Draw the same shortcut on the selected photo and verify that the right emoji shows:

Four screens: flower close up with quote, add a shortcut with hearts, flower close up with heart, flower closeup with heart eye emoji

Congratulations, you machine learning ninja!

Machine Learning ninja warrior

Where to Go From Here?

Download the completed version of the project using the Download Materials button at the top or bottom of this tutorial.

Check out the Machine Learning in iOS video course to learn more about how to train your own models using Create ML and Turi Create. Beginning Machine Learning with Keras & Core ML walks you through how to train a neural network and convert it to Core ML.

The Create ML app lets you build, train and deploy machine learning models with no machine learning expertise required. You can also check out the official WWDC 2019 sessions on What’s New in Machine Learning and Training Object Detection Models in Create ML.

I hope you enjoyed this tutorial! If you have any questions or comments, please join the discussion below.