版本记录
版本号 | 时间 |
---|---|
V1.0 | 2020.05.16 星期六 |
前言
目前世界上科技界的所有大佬一致认为人工智能是下一代科技革命,苹果作为科技界的巨头,当然也会紧跟新的科技革命的步伐,其中 iOS API 就新推出了一个框架
Core ML
。ML是Machine Learning
的缩写,也就是机器学习,这正是现在很火的一个技术,它也是人工智能最核心的内容。感兴趣的可以看我写的下面几篇。
1. Core ML框架详细解析(一) —— Core ML基本概览
2. Core ML框架详细解析(二) —— 获取模型并集成到APP中
3. Core ML框架详细解析(三) —— 利用Vision和Core ML对图像进行分类
4. Core ML框架详细解析(四) —— 将训练模型转化为Core ML
5. Core ML框架详细解析(五) —— 一个Core ML简单示例(一)
6. Core ML框架详细解析(六) —— 一个Core ML简单示例(二)
7. Core ML框架详细解析(七) —— 减少Core ML应用程序的大小(一)
8. Core ML框架详细解析(八) —— 在用户设备上下载和编译模型(一)
9. Core ML框架详细解析(九) —— 用一系列输入进行预测(一)
10. Core ML框架详细解析(十) —— 集成自定义图层(一)
11. Core ML框架详细解析(十一) —— 创建自定义图层(一)
12. Core ML框架详细解析(十二) —— 用scikit-learn开始机器学习(一)
13. Core ML框架详细解析(十三) —— 使用Keras和Core ML开始机器学习(一)
14. Core ML框架详细解析(十四) —— 使用Keras和Core ML开始机器学习(二)
15. Core ML框架详细解析(十五) —— 机器学习:分类(一)
16. Core ML框架详细解析(十六) —— 人工智能和IBM Watson Services(一)
17. Core ML框架详细解析(十七) —— Core ML 和 Vision简单示例(一)
18. Core ML框架详细解析(十八) —— 基于Core ML 和 Vision的设备上的训练(一)
源码
1. Swift
首先看一下工程组织结构
接着看下sb中的内容
下面就是源码了
1. UpdatableModel.swift
import CoreML
/// Namespace-style wrapper around the updatable Core ML drawing classifier.
/// Tracks the bundled default model and the personalized (retrained) copy
/// kept in the app's Application Support directory.
struct UpdatableModel {
  // MARK: - Properties

  /// Most recently loaded classifier; nil until loadModel() succeeds.
  private static var updatedDrawingClassifier: UpdatableDrawingClassifier?

  /// Application Support directory. The force-unwrap is safe: the system
  /// always returns at least one URL for this search path.
  private static let appDirectory =
    FileManager.default.urls(for: .applicationSupportDirectory,
                             in: .userDomainMask).first!

  /// Compiled model shipped inside the app bundle (the starting point).
  private static let defaultModelURL =
    UpdatableDrawingClassifier.urlOfModelInThisBundle

  /// On-disk location of the personalized, retrained model.
  private static var updatedModelURL =
    appDirectory.appendingPathComponent("personalized.mlmodelc")

  /// Scratch location used while saving a newly retrained model.
  private static var tempUpdatedModelURL =
    appDirectory.appendingPathComponent("personalized_tmp.mlmodelc")

  /// All members are static; prevent instantiation.
  private init() { }

  /// Image constraint (size/pixel format) the model expects for its
  /// drawing input, read from the loaded model when one is available.
  static var imageConstraint: MLImageConstraint {
    let model = updatedDrawingClassifier ?? UpdatableDrawingClassifier()
    return model.imageConstraint
  }
}
// MARK: - Public Methods
extension UpdatableModel {
  /// Runs the classifier on a single feature value and returns the
  /// predicted label, or nil when no model is loaded or nothing matches.
  static func predictLabelFor(_ value: MLFeatureValue) -> String? {
    loadModel()
    guard let classifier = updatedDrawingClassifier else { return nil }
    return classifier.predictLabelFor(value)
  }

  /// Retrains the model with the supplied batch, persists the result,
  /// and invokes `completionHandler` on the main queue.
  static func updateWith(
    trainingData: MLBatchProvider,
    completionHandler: @escaping () -> Void
  ) {
    loadModel()
    let handleUpdate: (MLUpdateContext) -> Void = { updateContext in
      saveUpdatedModel(updateContext)
      DispatchQueue.main.async(execute: completionHandler)
    }
    UpdatableDrawingClassifier.updateModel(
      at: updatedModelURL,
      with: trainingData,
      completionHandler: handleUpdate)
  }
}
// MARK: - Private Methods
private extension UpdatableModel {
  /// Persists a freshly retrained model to disk.
  ///
  /// Writes to a temporary location first, then atomically replaces the
  /// live personalized model so a failed write never corrupts it.
  static func saveUpdatedModel(_ updateContext: MLUpdateContext) {
    let updatedModel = updateContext.model
    let fileManager = FileManager.default
    do {
      // A compiled .mlmodelc is a directory, so create the destination
      // (with intermediates) before writing the model into it.
      try fileManager.createDirectory(
        at: tempUpdatedModelURL,
        withIntermediateDirectories: true,
        attributes: nil)
      try updatedModel.write(to: tempUpdatedModelURL)
      // Atomically swap the temp copy into the live location.
      _ = try fileManager.replaceItemAt(
        updatedModelURL,
        withItemAt: tempUpdatedModelURL)
      print("Updated model saved to:\n\t\(updatedModelURL)")
    } catch let error {
      print("Could not save updated model to the file system: \(error)")
      return
    }
  }

  /// Loads the personalized classifier into `updatedDrawingClassifier`.
  ///
  /// On first run (no personalized copy on disk yet) the bundled default
  /// model is copied into Application Support under a temporary name and
  /// then moved, so the final URL only ever appears fully populated.
  static func loadModel() {
    let fileManager = FileManager.default
    // Check if we need to copy over the initial model content
    if !fileManager.fileExists(atPath: updatedModelURL.path) {
      do {
        let updatedModelParentURL =
          updatedModelURL.deletingLastPathComponent()
        try fileManager.createDirectory(
          at: updatedModelParentURL,
          withIntermediateDirectories: true,
          attributes: nil)
        // Copy under the bundle model's own file name first, then move
        // into place under the personalized name.
        let toTemp = updatedModelParentURL
          .appendingPathComponent(defaultModelURL.lastPathComponent)
        try fileManager.copyItem(at: defaultModelURL, to: toTemp)
        try fileManager.moveItem(at: toTemp, to: updatedModelURL)
      } catch {
        print("Error: \(error)")
        return
      }
    }
    // Load (or reload) the classifier from the personalized model URL.
    guard let model =
      try? UpdatableDrawingClassifier(contentsOf: updatedModelURL) else {
      return
    }
    updatedDrawingClassifier = model
  }
}
// MARK: - UpdatableDrawingClassifier Extension
extension UpdatableDrawingClassifier {
  /// The image constraint declared for the model's "drawing" input.
  /// Force-unwraps are acceptable: the input and its constraint are part
  /// of the compiled model definition, so absence is a programmer error.
  var imageConstraint: MLImageConstraint {
    return model.modelDescription
      .inputDescriptionsByName["drawing"]!
      .imageConstraint!
  }

  /// Starts an on-device update task for the model at `url` using
  /// `trainingData`; `completionHandler` receives the update context
  /// when training finishes.
  static func updateModel(
    at url: URL,
    with trainingData: MLBatchProvider,
    completionHandler: @escaping (MLUpdateContext) -> Void
  ) {
    do {
      let updateTask = try MLUpdateTask(
        forModelAt: url,
        trainingData: trainingData,
        configuration: nil,
        completionHandler: completionHandler)
      updateTask.resume()
    } catch {
      // Fixed "Could't" typo and surfaced the underlying error, which was
      // previously swallowed, to aid debugging failed update tasks.
      print("Couldn't create an MLUpdateTask: \(error)")
    }
  }

  /// Classifies a drawing feature value. Returns nil when the value has
  /// no image buffer, prediction fails, or the model answers "unknown".
  func predictLabelFor(_ value: MLFeatureValue) -> String? {
    guard
      let pixelBuffer = value.imageBufferValue,
      let prediction = try? prediction(drawing: pixelBuffer).label
    else {
      return nil
    }
    if prediction == "unknown" {
      print("No prediction found")
      return nil
    }
    return prediction
  }
}
2. Quote.swift
import Foundation
/// A quotation with searchable keywords, loaded from Quotes.plist.
struct Quote {
  // MARK: - Enums

  /// Dictionary keys used when reading a quote entry out of the plist.
  enum Key {
    static let text = "text"
    static let author = "author"
    static let keywords = "keywords"
  }

  // MARK: - Properties
  var text: String
  var author: String
  // Keywords matched against image-classification labels to pick a quote.
  var keywords: [String]
}
3. Drawing.swift
import Foundation
import PencilKit
import CoreML
/// A single PencilKit drawing plus the canvas rectangle it was captured
/// from, convertible to an `MLFeatureValue` for training and prediction.
struct Drawing {
  // MARK: - Properties

  // Shared context: creating a CIContext is expensive, so reuse one.
  private static let ciContext = CIContext()

  let drawing: PKDrawing
  let rect: CGRect

  /// The drawing rendered as an image feature value matching the model's
  /// input constraint. NOTE(review): the force-unwrap crashes if the
  /// CGImage-to-feature conversion fails — consider propagating the error.
  var featureValue: MLFeatureValue {
    // Get the model's image constraints.
    let imageConstraint = UpdatableModel.imageConstraint
    // Get a white tinted version to use for the model
    let preparedImage = whiteTintedImage
    let imageFeatureValue = try? MLFeatureValue(cgImage: preparedImage, constraint: imageConstraint)
    return imageFeatureValue!
  }

  /// Renders the drawing and pushes brightness to maximum via the
  /// CIColorControls filter — presumably to match the appearance of the
  /// model's training images (TODO confirm against the model's data).
  private var whiteTintedImage: CGImage {
    let ciContext = Drawing.ciContext
    let parameters = [kCIInputBrightnessKey: 1.0]
    let image = drawing.image(from: rect, scale: UIScreen.main.scale)
    let ciImage = CIImage(
      cgImage: image.cgImage!)
      .applyingFilter("CIColorControls", parameters: parameters)
    return ciContext.createCGImage(ciImage, from: ciImage.extent)!
  }
}
4. DrawingsDataStore.swift
import CoreML
/// Holds a fixed-capacity set of training drawings for one emoji.
class DrawingDataStore: NSObject {
  // MARK: - Properties

  // Fixed-size slots; nil marks a slot the user has not drawn in yet.
  private var drawings: [Drawing?]

  /// The emoji these drawings represent (the training label).
  let emoji: String

  /// Number of slots that actually contain a drawing.
  var numberOfDrawings: Int {
    return drawings.compactMap { $0 }.count
  }

  init(for emoji: String, capacity: Int) {
    self.emoji = emoji
    self.drawings = Array(repeating: nil, count: capacity)
  }
}
// MARK: - Helper methods
extension DrawingDataStore {
  /// Stores `drawing` in the slot at `index`, replacing any previous one.
  func addDrawing(_ drawing: Drawing, at index: Int) {
    drawings[index] = drawing
  }

  /// Builds an `MLBatchProvider` from every non-empty slot, pairing each
  /// drawing ("drawing" input) with this store's emoji ("label" output).
  func prepareTrainingData() throws -> MLBatchProvider {
    let inputName = "drawing"
    let outputName = "label"
    var featureProviders: [MLFeatureProvider] = []
    // `for case let ...?` iterates only the slots that hold a drawing.
    for case let drawing? in drawings {
      let dataPointFeatures: [String: MLFeatureValue] = [
        inputName: drawing.featureValue,
        outputName: MLFeatureValue(string: emoji)
      ]
      guard let provider =
        try? MLDictionaryFeatureProvider(dictionary: dataPointFeatures) else {
        continue
      }
      featureProviders.append(provider)
    }
    return MLArrayBatchProvider(array: featureProviders)
  }
}
5. DrawingView.swift
import UIKit
import PencilKit
/// A view that hosts a PencilKit canvas and reports drawing changes
/// to its delegate.
class DrawingView: UIView {
  // MARK: - Properties

  // The embedded PencilKit canvas; created in setupPencilKitCanvas().
  var canvasView: PKCanvasView!
  weak var delegate: DrawingViewDelegate?

  // MARK: - Lifecycle
  override init(frame: CGRect) {
    super.init(frame: frame)
    setupPencilKitCanvas()
  }

  required init?(coder aDecoder: NSCoder) {
    super.init(coder: aDecoder)
    setupPencilKitCanvas()
  }
}
// MARK: - Private methods
private extension DrawingView {
  /// Creates the PencilKit canvas, configures it with a black pen on a
  /// transparent background, and embeds it in this view.
  func setupPencilKitCanvas() {
    let canvas = PKCanvasView(frame: bounds)
    canvas.backgroundColor = .clear
    canvas.isOpaque = false
    canvas.delegate = self
    canvas.tool = PKInkingTool(.pen, color: .black, width: 10)
    canvasView = canvas
    addSubview(canvas)
  }
}
// MARK: - Helper methods
extension DrawingView {
  /// Returns a square centered on the current drawing's bounds whose
  /// side equals the larger of the drawing's width and height.
  func boundingSquare() -> CGRect {
    let drawingBounds = canvasView.drawing.bounds
    let side = max(drawingBounds.size.width, drawingBounds.size.height)
    // Negative insets grow the shorter dimension out to `side`.
    let dx = (drawingBounds.size.width - side) / 2
    let dy = (drawingBounds.size.height - side) / 2
    return drawingBounds.insetBy(dx: dx, dy: dy)
  }

  /// Removes all strokes by installing an empty drawing.
  func clearCanvas() {
    canvasView.drawing = PKDrawing()
  }
}
// MARK: - PKCanvasViewDelegate
extension DrawingView: PKCanvasViewDelegate {
  /// Forwards drawing changes to the delegate. Clearing the canvas also
  /// fires this callback (with empty bounds), so that case is ignored.
  func canvasViewDrawingDidChange(_ canvasView: PKCanvasView) {
    if canvasView.drawing.bounds.size == .zero {
      return
    }
    delegate?.drawingDidChange(self)
  }
}
// MARK: - Protocol

/// Notifies interested parties that the user changed a drawing.
/// `AnyObject` replaces the deprecated `class` constraint (the constraint
/// is required so `delegate` can be held weakly); also removed a stray
/// backtick after the closing brace.
protocol DrawingViewDelegate: AnyObject {
  func drawingDidChange(_ drawingView: DrawingView)
}
6. EmojiViewCell.swift
import UIKit
/// Collection view cell displaying a single emoji.
class EmojiViewCell: UICollectionViewCell {
  @IBOutlet weak var emojiLabel: UILabel!

  // The emoji to display; updates the label whenever set.
  var emoji: String? {
    didSet {
      emojiLabel?.text = emoji
    }
  }
}
7. CreateQuoteViewController.swift
import UIKit
import CoreML
import Vision
/// Main screen: lets the user pick a photo, classifies it with Vision to
/// suggest a matching quote, and supports placing emoji stickers plus
/// drawn emoji shortcuts on top of the image.
class CreateQuoteViewController: UIViewController {
  // MARK: - Properties
  @IBOutlet weak var imageView: UIImageView!
  @IBOutlet weak var quoteTextView: UITextView!
  @IBOutlet weak var addStickerButton: UIBarButtonItem!
  @IBOutlet weak var stickerView: UIView!
  @IBOutlet weak var starterLabel: UILabel!

  // Transparent overlay used to recognize drawn emoji shortcuts.
  var drawingView: DrawingView!

  /// Quotes loaded once from Quotes.plist; empty on any read failure.
  private lazy var quoteList: [Quote] = {
    guard let path = Bundle.main.path(forResource: "Quotes", ofType: "plist")
    else {
      print("Failed to read Quotes.plist")
      return []
    }
    let fileUrl = URL.init(fileURLWithPath: path)
    guard let quotesArray = NSArray(contentsOf: fileUrl) as? [Dictionary<String, Any>]
    else { return [] }
    // Drop any plist entry missing one of the required fields.
    let quotes: [Quote] = quotesArray.compactMap { (quote) in
      guard
        let text = quote[Quote.Key.text] as? String,
        let author = quote[Quote.Key.author] as? String,
        let keywords = quote[Quote.Key.keywords] as? [String]
      else { return nil }
      return Quote(
        text: text,
        author: author,
        keywords: keywords)
    }
    return quotes
  }()

  /// Default frame for a newly added sticker: 50x50 points, horizontally
  /// centered near the top of the sticker canvas.
  private lazy var stickerFrame: CGRect = {
    let stickerHeightWidth = 50.0
    let stickerOffsetX = Double(stickerView.bounds.midX) - (stickerHeightWidth / 2.0)
    let stickerRect = CGRect(
      x: stickerOffsetX,
      y: 80.0,
      width: stickerHeightWidth,
      height: stickerHeightWidth)
    return stickerRect
  }()

  /// Vision request wrapping the bundled SqueezeNet classifier; results
  /// are routed to processClassifications(for:error:).
  private lazy var classificationRequest: VNCoreMLRequest = {
    do {
      let model = try VNCoreMLModel(for: SqueezeNet().model)
      let request = VNCoreMLRequest(
        model: model) { [weak self] request, error in
        guard let self = self else {
          return
        }
        self.processClassifications(for: request, error: error)
      }
      request.imageCropAndScaleOption = .centerCrop
      return request
    } catch {
      // A missing bundled model is a build defect; crashing is acceptable.
      fatalError("Failed to load Vision ML model: \(error)")
    }
  }()

  // MARK: - Lifecycle
  override func viewDidLoad() {
    super.viewDidLoad()
    // Quote and sticker UI stay hidden until the user picks a photo.
    quoteTextView.isHidden = true
    addStickerButton.isEnabled = false
    addCanvasForDrawing()
    drawingView.isHidden = true
  }

  // MARK: - Actions

  /// Presents the photo library picker.
  @IBAction func selectPhotoPressed(_ sender: Any) {
    let picker = UIImagePickerController()
    picker.delegate = self
    picker.sourceType = .photoLibrary
    picker.modalPresentationStyle = .overFullScreen
    present(picker, animated: true)
  }

  @IBAction func cancelPressed(_ sender: Any) {
    dismiss(animated: true)
  }

  /// Unwind segue from the sticker chooser: places the selected emoji
  /// at the default sticker frame.
  @IBAction func addStickerDoneUnwind(_ unwindSegue: UIStoryboardSegue) {
    guard
      let sourceViewController = unwindSegue.source as? AddStickerViewController,
      let selectedEmoji = sourceViewController.selectedEmoji
    else {
      return
    }
    addStickerToCanvas(selectedEmoji, at: stickerFrame)
  }
}
// MARK: - Private methods
private extension CreateQuoteViewController {
  /// Adds an emoji label sticker to the sticker canvas at `rect`.
  func addStickerToCanvas(_ sticker: String, at rect: CGRect) {
    let label = UILabel(frame: rect)
    label.text = sticker
    label.font = .systemFont(ofSize: 100)
    label.numberOfLines = 1
    label.baselineAdjustment = .alignCenters
    label.textAlignment = .center
    label.adjustsFontSizeToFitWidth = true
    // Add sticker to the canvas
    stickerView.addSubview(label)
  }

  /// Removes every sticker from the canvas.
  func clearStickersFromCanvas() {
    stickerView.subviews.forEach { $0.removeFromSuperview() }
  }

  /// Creates the transparent drawing overlay and pins it to the sticker
  /// view's edges.
  func addCanvasForDrawing() {
    drawingView = DrawingView(frame: stickerView.bounds)
    drawingView.delegate = self
    view.addSubview(drawingView)
    drawingView.translatesAutoresizingMaskIntoConstraints = false
    let edges = [
      drawingView.topAnchor.constraint(equalTo: stickerView.topAnchor),
      drawingView.leftAnchor.constraint(equalTo: stickerView.leftAnchor),
      drawingView.rightAnchor.constraint(equalTo: stickerView.rightAnchor),
      drawingView.bottomAnchor.constraint(equalTo: stickerView.bottomAnchor)
    ]
    NSLayoutConstraint.activate(edges)
  }

  /// Returns the first quote matching any of `keywords` (checked in
  /// keyword order), or a random quote when nothing matches or no
  /// keywords are given.
  func getQuote(for keywords: [String]? = nil) -> Quote? {
    if let keywords = keywords {
      for keyword in keywords {
        if let match = quoteList.first(where: { $0.keywords.contains(keyword) }) {
          return match
        }
      }
    }
    return selectRandomQuote()
  }

  /// A random quote, or nil when the list is empty.
  func selectRandomQuote() -> Quote? {
    return quoteList.randomElement()
  }

  /// Handles Vision results on the main queue: logs the top two
  /// classifications and shows a quote matching their identifiers.
  func processClassifications(for request: VNRequest, error: Error?) {
    DispatchQueue.main.async {
      guard let classifications =
        request.results as? [VNClassificationObservation] else {
        return
      }
      let topClassifications = classifications.prefix(2).map {
        (confidence: $0.confidence, identifier: $0.identifier)
      }
      print("Top classifications: \(topClassifications)")
      let topIdentifiers = topClassifications.map { $0.identifier.lowercased() }
      if let quote = self.getQuote(for: topIdentifiers) {
        self.quoteTextView.text = quote.text
      }
    }
  }

  /// Classifies `image` on a background queue using the SqueezeNet request.
  func classifyImage(_ image: UIImage) {
    guard let orientation = CGImagePropertyOrientation(
      rawValue: UInt32(image.imageOrientation.rawValue)) else {
      return
    }
    guard let ciImage = CIImage(image: image) else {
      fatalError("Unable to create \(CIImage.self) from \(image).")
    }
    DispatchQueue.global(qos: .userInitiated).async {
      let handler =
        VNImageRequestHandler(ciImage: ciImage, orientation: orientation)
      do {
        try handler.perform([self.classificationRequest])
      } catch {
        print("Failed to perform classification.\n\(error.localizedDescription)")
      }
    }
  }
}
// MARK: - UIImagePickerControllerDelegate
extension CreateQuoteViewController: UIImagePickerControllerDelegate, UINavigationControllerDelegate {
  /// Shows the chosen photo, reveals the quote/sticker UI, and kicks off
  /// classification to pick a matching quote.
  func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey : Any]) {
    picker.dismiss(animated: true)
    // Was a force cast (`as! UIImage`); a missing original image now
    // fails gracefully instead of crashing the app.
    guard let image = info[.originalImage] as? UIImage else {
      return
    }
    imageView.image = image
    quoteTextView.isHidden = false
    addStickerButton.isEnabled = true
    drawingView.isHidden = false
    starterLabel.isHidden = true
    clearStickersFromCanvas()
    classifyImage(image)
  }
}
// MARK: - UIGestureRecognizerDelegate
extension CreateQuoteViewController: UIGestureRecognizerDelegate {
  /// Drags a sticker with the pan gesture; on release, applies a short
  /// velocity-based slide, clamped to the sticker canvas bounds.
  @objc func handlePanGesture(_ recognizer: UIPanGestureRecognizer) {
    // Bind the view once instead of force-unwrapping `recognizer.view!`
    // repeatedly in the .ended branch.
    guard let view = recognizer.view else { return }
    let translation = recognizer.translation(in: stickerView)
    view.center = CGPoint(
      x: view.center.x + translation.x,
      y: view.center.y + translation.y)
    recognizer.setTranslation(CGPoint.zero, in: stickerView)
    if recognizer.state == UIGestureRecognizer.State.ended {
      let velocity = recognizer.velocity(in: stickerView)
      let magnitude =
        sqrt((velocity.x * velocity.x) + (velocity.y * velocity.y))
      // Scale the slide distance/duration by how hard the user flung.
      let slideMultiplier = magnitude / 200
      let slideFactor = 0.1 * slideMultiplier
      var finalPoint = CGPoint(
        x: view.center.x + (velocity.x * slideFactor),
        y: view.center.y + (velocity.y * slideFactor))
      // Keep the sticker inside the canvas.
      finalPoint.x =
        min(max(finalPoint.x, 0), stickerView.bounds.size.width)
      finalPoint.y =
        min(max(finalPoint.y, 0), stickerView.bounds.size.height)
      UIView.animate(
        withDuration: Double(slideFactor * 2),
        delay: 0,
        options: UIView.AnimationOptions.curveEaseOut,
        animations: { view.center = finalPoint },
        completion: nil)
    }
  }
}
// MARK: - DrawingViewDelegate
extension CreateQuoteViewController: DrawingViewDelegate {
  /// Classifies the finished drawing and, when an emoji label is
  /// predicted, replaces the drawing with that emoji sticker.
  func drawingDidChange(_ drawingView: DrawingView) {
    let square = drawingView.boundingSquare()
    let drawing = Drawing(
      drawing: drawingView.canvasView.drawing,
      rect: square)
    let predictedLabel = UpdatableModel.predictLabelFor(drawing.featureValue)
    DispatchQueue.main.async {
      drawingView.clearCanvas()
      if let emoji = predictedLabel {
        self.addStickerToCanvas(emoji, at: square)
      }
    }
  }
}
8. AddStickerViewController.swift
import UIKit
private let reuseIdentifier = "EmojiCell"

/// All emoji offered as stickers/shortcuts.
/// Bug fix: the multi-line literal contains a newline after each row's
/// trailing comma, so splitting on "," alone left a leading "\n" on the
/// first emoji of every row after the first. Strip line breaks first.
let emoji = """
😀,😂,🤣,😅,😆,😉,😊,😎,😍,😘,🥰,🙂,
🤗,🤩,🤔,😐,😑,🙄,😣,😥,😮,🤐,😯,😪,
😫,😴,😛,😝,😓,🤑,☹️,😖,😤,😭,😨,😩,
🤯,😬,😰,😱,🥵,🥶,😳,🤪,😵,😡,🤬,😷,
🤒,🤕,🤢,🤮,🤧,😇,🤠,🤡,🥳,🤫,😈,👿,
👹,👺,💀,👻,👽,🤖,💩,🙌🏼,🙏🏼,👍🏼,👎🏼,👊🏼,
👋🏼,🤙🏼,💪🏼
"""
.components(separatedBy: .newlines)
.joined()
.components(separatedBy: ",")
/// Grid of emoji the user can pick as a sticker or drawing shortcut.
class AddStickerViewController: UICollectionViewController {
  /// The emoji tapped by the user, if any.
  var selectedEmoji: String?

  // MARK: - Lifecycle

  /// Passes the chosen emoji along to the shortcut-creation screen.
  override func prepare(for segue: UIStoryboardSegue, sender: Any?) {
    guard segue.identifier == "AddShortcutSegue",
      let destination = segue.destination as? AddShortcutViewController,
      let chosenEmoji = selectedEmoji else {
      return
    }
    destination.selectedEmoji = chosenEmoji
  }

  // MARK: - Actions
  @IBAction func cancelPressed(_ sender: Any) {
    dismiss(animated: true)
  }
}
// MARK: - UICollectionViewDataSource
extension AddStickerViewController {
  /// One item per available emoji.
  override func collectionView(
    _ collectionView: UICollectionView,
    numberOfItemsInSection section: Int
  ) -> Int {
    return emoji.count
  }

  /// Dequeues an EmojiViewCell configured with this item's emoji.
  override func collectionView(
    _ collectionView: UICollectionView,
    cellForItemAt indexPath: IndexPath
  ) -> UICollectionViewCell {
    let cell = collectionView.dequeueReusableCell(
      withReuseIdentifier: reuseIdentifier,
      for: indexPath)
    guard let emojiCell = cell as? EmojiViewCell else { return cell }
    emojiCell.emoji = emoji[indexPath.item]
    return emojiCell
  }
}
// MARK: - UICollectionViewDelegate
extension AddStickerViewController {
  /// Records the tapped emoji and moves on to shortcut creation.
  override func collectionView(
    _ collectionView: UICollectionView,
    didSelectItemAt indexPath: IndexPath
  ) {
    selectedEmoji = emoji[indexPath.item]
    performSegue(withIdentifier: "AddShortcutSegue", sender: self)
  }
}
9. AddShortcutViewController.swift
import UIKit
/// Collects three sample drawings for an emoji and retrains the model
/// so the drawing becomes a shortcut for that emoji.
class AddShortcutViewController: UIViewController {
  // MARK: - Properties
  @IBOutlet weak var emojiLabel: UILabel!
  @IBOutlet weak var drawingView1: DrawingView!
  @IBOutlet weak var drawingView2: DrawingView!
  @IBOutlet weak var drawingView3: DrawingView!
  @IBOutlet weak var saveButton: UIBarButtonItem!

  /// Emoji being trained; set by the presenting screen before display.
  var selectedEmoji: String?
  private var drawingViews: [DrawingView] = []

  // One drawing per canvas is required before saving.
  private var minimumNumberOfDrawingsRequired: Int {
    return drawingViews.count
  }

  // Created in viewDidLoad. NOTE(review): implicitly unwrapped — if
  // selectedEmoji were nil the store is never created and savePressed
  // would crash; confirm the segue always sets selectedEmoji.
  private var drawingDataStore: DrawingDataStore!

  // MARK: - Lifecycle
  override func viewDidLoad() {
    super.viewDidLoad()
    drawingViews = [drawingView1, drawingView2, drawingView3]
    drawingView1?.delegate = self
    drawingView2?.delegate = self
    drawingView3?.delegate = self
    if let selectedEmoji = selectedEmoji {
      emojiLabel.text = selectedEmoji
      drawingDataStore = DrawingDataStore(
        for: selectedEmoji,
        capacity: minimumNumberOfDrawingsRequired)
    }
    // Save stays disabled until every canvas has a drawing.
    saveButton.isEnabled = false
  }

  // MARK: - Actions

  /// Builds the training batch and retrains the model off the main
  /// queue; unwinds back to the previous screen on completion.
  @IBAction func savePressed(_ sender: Any) {
    do {
      let trainingData = try drawingDataStore.prepareTrainingData()
      DispatchQueue.global(qos: .userInitiated).async {
        UpdatableModel.updateWith(trainingData: trainingData) {
          DispatchQueue.main.async {
            self.performSegue(
              withIdentifier: "AddShortcutUnwindSegue",
              sender: self)
          }
        }
      }
    } catch {
      print("Error updating model", error)
    }
  }
}
// MARK: - DrawingViewDelegate
extension AddShortcutViewController: DrawingViewDelegate {
  /// Stores the changed drawing in the slot matching its canvas and
  /// enables Save once every canvas has been drawn in.
  func drawingDidChange(_ drawingView: DrawingView) {
    guard let index = drawingViews.firstIndex(of: drawingView) else { return }
    let square = drawingView.boundingSquare()
    let sample = Drawing(
      drawing: drawingView.canvasView.drawing,
      rect: square)
    drawingDataStore.addDrawing(sample, at: index)
    saveButton.isEnabled =
      drawingDataStore.numberOfDrawings >= minimumNumberOfDrawingsRequired
  }
}
后记
本篇主要讲述了基于Core ML 和 Vision的设备上的训练,感兴趣的给个赞或者关注~~~