🍎 Honeycrisp

文档 | 示例 | 软件包索引页

自动微分和神经网络,全部使用 Swift 语言,专为 Apple 芯片设计。

示例

请参阅 honeycrisp-examples 以获取深入的使用示例。

张量和运算

我们可以创建一个具有形状和数据的张量

// Create a 2x3 matrix:
//   1  2  3
//   4  5  6
let matrix = Tensor(data: [1, 2, 3, 4, 5, 6], shape: [2, 3])

您可以对张量执行运算以获得新的张量

let matrixPlus1 = matrix + 1
let sumOfColumns = matrix.sum(axis: 1)

我们可以使用 try await 从张量中获取数据

// Print a [Float] from the raw data of the matrix
print("data as floats:", try await matrix.floats())

可互换的后端

我们可以在不同的后端运行我们计算的不同部分

Backend.defaultBackend = try MPSBackend()  // Use the GPU by default
let cpuBackend = CPUBackend()
let x = Tensor(rand: [128, 128])  // Performed on GPU
let y = cpuBackend.use { x + 3 }  // Performed on CPU
let z = y - 3  // Performed on GPU

完整训练示例

这是一个在简单目标上训练虚拟模型的完整示例。

首先,我们定义一个带有可训练参数和子模块的模型

class MyModel: Trainable {
  // A parameter which will be tracked automatically
  @Param var someParameter: Tensor

  // We can also give parameters custom names
  @Param(name: "customName") var otherParameter: Tensor

  // A sub-module whose parameters will also be tracked
  @Child var someLayer: Linear

  override init() {
    super.init()
    self.someParameter = Tensor(data: [1.0])
    self.otherParameter = Tensor(zeros: [7])
    self.someLayer = Linear(inCount: 3, outCount: 7)
  }

  func callAsFunction(_ input: Tensor) -> Tensor {
    // We can access properties like normal
    return someParameter * (someLayer(input) + otherParameter)
  }
}

训练循环如下所示

@main
struct Main {
  static func main() async {
    do {
      let model = MyModel()
      let optimizer = Adam(model.parameters, lr: 0.1)

      // We will use the same input batch for all iterations.
      let batchSize = 8
      let input = Tensor(rand: [batchSize, 3])

      for i in 0..<10 {
        let output = model(input)
        let loss = output.pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.clearGrads()
        print("step \(i): loss=\(try await loss.item())")
      }
    } catch {
      print("FATAL ERROR: \(error)")
    }
  }
}