Swift Algorithm Club: Strassen’s Algorithm

In this tutorial, you’ll learn how to implement Strassen’s Matrix Multiplication in Swift. This was the first matrix multiplication algorithm to beat the naive O(n³) implementation, and is a fantastic example of the Divide and Conquer coding paradigm — a favorite topic in coding interviews. By Richard Ash.

Implementing in Swift

Now, to the implementation! Start by adding the following extension to the bottom of your playground:

extension Matrix {
  public func strassenMatrixMultiply(by other: Matrix) -> Matrix {
    // More code to come!
  }
}

Now, just like in the naive implementation, you need to check that the first matrix’s column count is equal to the second matrix’s row count.

Replace the comment with the following:

precondition(columnCount == other.rowCount, """
      Two matrices can only be multiplied if the first matrix's column count \
      is equal to the second matrix's row count.
      """)

Time for some prep work! Add the following right below the precondition:

// 1
let n = Swift.max(rowCount, columnCount, other.rowCount, other.columnCount)
// 2
let m = nextPowerOfTwo(after: n)

// 3
var firstPrep = Matrix(rows: m, columns: m)
var secondPrep = Matrix(rows: m, columns: m)

// 4
for index in indices {
  firstPrep[index.row, index.column] = self[index]
}
for index in other.indices {
  secondPrep[index.row, index.column] = other[index]
}

Reviewing what’s going on here, you:

  1. Find the largest dimension among the row and column counts of both matrices.
  2. Find the smallest power of two that is greater than or equal to that number.
  3. Create two new matrices whose rows and columns are equal to the next power of two.
  4. Copy the elements from the first and second matrices into their respective prep matrices.
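To see the four prep steps on concrete numbers, here's a minimal sketch that stands apart from the tutorial's Matrix type. It assumes a plain [[Double]] grid; the Grid alias and the helpers smallestPowerOfTwo(atLeast:) and padded(_:to:) are hypothetical names used only here. It pads a 3×2 and a 2×3 matrix:

```swift
// Plain row-major grid standing in for the tutorial's Matrix (assumption).
typealias Grid = [[Double]]

/// Smallest power of two that is >= n, found by repeated doubling (n >= 1).
func smallestPowerOfTwo(atLeast n: Int) -> Int {
  var power = 1
  while power < n { power *= 2 }
  return power
}

/// Copies `grid` into the top-left corner of an m×m grid of zeros.
func padded(_ grid: Grid, to m: Int) -> Grid {
  var result = Grid(repeating: [Double](repeating: 0, count: m), count: m)
  for (i, row) in grid.enumerated() {
    for (j, value) in row.enumerated() {
      result[i][j] = value
    }
  }
  return result
}

let a: Grid = [[1, 2], [3, 4], [5, 6]]      // 3×2
let b: Grid = [[7, 8, 9], [10, 11, 12]]     // 2×3

// Step 1: largest row or column count across both matrices.
let n = max(a.count, a[0].count, b.count, b[0].count)    // 3
// Step 2: smallest power of two >= n.
let m = smallestPowerOfTwo(atLeast: n)                    // 4
// Steps 3 and 4: m×m zero grids with the originals copied into the top-left corner.
let aPrep = padded(a, to: m)
let bPrep = padded(b, to: m)

print(aPrep)  // [[1.0, 2.0, 0.0, 0.0], [3.0, 4.0, 0.0, 0.0], [5.0, 6.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
```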

This seems like extra work: Why is this necessary? Great question! Next, you’ll investigate with an example.

Say you have a 3×2 matrix, A. Strassen's algorithm splits every matrix into four equally sized quadrants, so how should you split this one? Should the middle row go with the top half or the bottom half? Because there's no even way to split this matrix, this edge case would need to be handled explicitly. While that sounds difficult, the prep work above removes this possibility completely.

By enlarging each matrix until it is a square matrix whose row and column counts are a power of two, you ensure the edge case never occurs: every split yields halves of equal size, all the way down to 1×1. Additionally, because the prep work only adds rows and columns of zeros, the top-left rowCount × other.columnCount block of the padded product is exactly the product you want. 🎉
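A quick way to convince yourself that zero padding leaves the product unchanged is to multiply with and without padding and compare. This standalone sketch uses a hypothetical [[Double]] Grid alias and a naive triple-loop multiply(_:_:) in place of the tutorial's Matrix type; the extra zero entries only ever add 0 to each sum:

```swift
typealias Grid = [[Double]]

/// Naive triple-loop product of an r×k grid and a k×c grid.
func multiply(_ x: Grid, _ y: Grid) -> Grid {
  let rows = x.count, inner = y.count, cols = y[0].count
  precondition(x[0].count == inner, "x's column count must equal y's row count")
  var result = Grid(repeating: [Double](repeating: 0, count: cols), count: rows)
  for i in 0..<rows {
    for j in 0..<cols {
      for k in 0..<inner {
        result[i][j] += x[i][k] * y[k][j]
      }
    }
  }
  return result
}

/// Copies `grid` into the top-left corner of an m×m grid of zeros.
func padded(_ grid: Grid, to m: Int) -> Grid {
  var result = Grid(repeating: [Double](repeating: 0, count: m), count: m)
  for (i, row) in grid.enumerated() {
    for (j, value) in row.enumerated() { result[i][j] = value }
  }
  return result
}

let a: Grid = [[1, 2], [3, 4], [5, 6]]      // 3×2
let b: Grid = [[7, 8, 9], [10, 11, 12]]     // 2×3

let direct = multiply(a, b)                                    // 3×3
let viaPadding = multiply(padded(a, to: 4), padded(b, to: 4))  // 4×4

// Keep only the top-left 3×3 block of the padded product.
let cropped = viaPadding.prefix(3).map { Array($0.prefix(3)) }
print(cropped == direct)   // true
```

The padded product's extra row and column come out as all zeros, which is why copying just the top-left block back out is enough.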

Now, to finish the method, add the following to strassenMatrixMultiply:

// 1
let resultPrep = firstPrep.strassenRecursive(by: secondPrep)
// 2
var result = Matrix(rows: rowCount, columns: other.columnCount)
// 3
for index in result.indices {
  result[index] = resultPrep[index.row, index.column]
}
// 4
return result

Here, you:

  1. Multiply the two padded matrices recursively.
  2. Initialize a new matrix with the correct dimensions: rowCount rows and other.columnCount columns.
  3. Iterate through the result matrix and copy over the element at the same index from the padded result, discarding the extra zero rows and columns.
  4. Return the result!

Good job! Almost done. You have two unimplemented methods left, nextPowerOfTwo and strassenRecursive. You’ll tackle those now.

nextPowerOfTwo

Add the following method below strassenMatrixMultiply:

private func nextPowerOfTwo(after n: Int) -> Int {
  // 1
  let logBaseTwoN = log2(Double(n))
  // 2
  let ceilLogBaseTwoN = ceil(logBaseTwoN)
  // 3
  let nextPowerOfTwo = pow(2, ceilLogBaseTwoN)
  return Int(nextPowerOfTwo)
}

This method takes a number and returns the smallest power of two that is greater than or equal to it. If the number is already a power of two, the method returns it unchanged.

Reviewing, you:

  1. Calculate the log base 2 of the input number.
  2. Take the ceiling of logBaseTwoN, which rounds it up to the nearest whole number.
  3. Calculate 2 to the ceilLogBaseTwoN power and convert it to an Int.
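The log2/ceil/pow approach goes through floating point, which is fine for the matrix sizes you'll try in a playground. If you'd rather stay in integer arithmetic, here's an alternative sketch (not the tutorial's implementation) built on Swift's leadingZeroBitCount. It assumes n ≤ 2^62 so the result fits in an Int, and it returns 1 for any n ≤ 1:

```swift
/// Smallest power of two >= n, using integer bit operations instead of
/// log2/ceil/pow. Returns 1 for n <= 1. Assumes n <= 2^62 (hypothetical limit
/// chosen so the shift below cannot reach the sign bit).
func nextPowerOfTwo(after n: Int) -> Int {
  guard n > 1 else { return 1 }
  // The highest set bit of (n - 1) sits at position bitWidth - leadingZeros - 1,
  // so shifting 1 one place past it gives the next power of two. When n is
  // already a power of two, n - 1 has only lower bits set, so the result is n.
  return 1 << (Int.bitWidth - (n - 1).leadingZeroBitCount)
}

for n in [1, 2, 3, 4, 5, 100] {
  print(n, "->", nextPowerOfTwo(after: n))
}
// 1 -> 1, 2 -> 2, 3 -> 4, 4 -> 4, 5 -> 8, 100 -> 128
```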

Challenge

To get a better idea of how this method works, try applying it to the following numbers. Don’t use code! Write out each step and use WolframAlpha to do the calculations.

  1. 3
  2. 4

Solution

For 3:

log2(3) ≈ 1.585
ceil(1.585) = 2
pow(2, 2) = 4
nextPowerOfTwo = 4

For 4:

log2(4) = 2
ceil(2) = 2
pow(2, 2) = 4
nextPowerOfTwo = 4

strassenRecursive

Next up, you need to implement strassenRecursive(by:). Start by adding the following below nextPowerOfTwo:

private func strassenRecursive(by other: Matrix) -> Matrix {
  assert(isSquare && other.isSquare, "This method requires square matrices!")
  guard rowCount > 1 && other.rowCount > 1 else { return self * other }
}

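Here, you set the base case for the recursion: if either matrix has only one row, which for these padded square matrices means both are 1×1, you just return the term-by-term multiplication of the two matrices. For 1×1 matrices, the term-by-term product and the matrix product are the same.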

Then, you need to split each input matrix into four submatrices, eight in total. Add this initialization to the method:

// 1
let n = rowCount
let nBy2 = n / 2

//  Assume submatrices are allocated as follows
//  matrix self = |a b|,    matrix other = |e f|
//                |c d|                    |g h|

// 2
var a = Matrix(rows: nBy2, columns: nBy2)
var b = Matrix(rows: nBy2, columns: nBy2)
var c = Matrix(rows: nBy2, columns: nBy2)
var d = Matrix(rows: nBy2, columns: nBy2)
var e = Matrix(rows: nBy2, columns: nBy2)
var f = Matrix(rows: nBy2, columns: nBy2)
var g = Matrix(rows: nBy2, columns: nBy2)
var h = Matrix(rows: nBy2, columns: nBy2)

// 3
for i in 0..<nBy2 {
  for j in 0..<nBy2 {
    a[i, j] = self[i, j]
    b[i, j] = self[i, j+nBy2]
    c[i, j] = self[i+nBy2, j]
    d[i, j] = self[i+nBy2, j+nBy2]
    e[i, j] = other[i, j]
    f[i, j] = other[i, j+nBy2]
    g[i, j] = other[i+nBy2, j]
    h[i, j] = other[i+nBy2, j+nBy2]
  }
}

OK! You:

  1. Store the size of the current matrix, n, and the size of each submatrix, nBy2.
  2. Initialize all eight submatrices.
  3. Fill each of the eight submatrices with the appropriate elements from the original matrices. Because every submatrix is nBy2×nBy2, you only need to loop over 0..<nBy2 instead of 0..<n, and since all eight submatrices share the same (i, j) indices, a single pass through the loop fills all eight at once!
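Here's the same single-pass split applied to a concrete 4×4 grid. It's a standalone sketch that assumes a plain [[Int]] representation instead of the tutorial's Matrix; quadrants(of:) is a hypothetical helper:

```swift
typealias Grid = [[Int]]

/// Splits an n×n grid (n even) into its four nBy2×nBy2 quadrants,
/// filling all four in one nBy2×nBy2 pass, like the tutorial's loop.
func quadrants(of m: Grid) -> (a: Grid, b: Grid, c: Grid, d: Grid) {
  let n = m.count
  precondition(n % 2 == 0 && m.allSatisfy { $0.count == n }, "need an even-sized square grid")
  let nBy2 = n / 2
  let empty = Grid(repeating: [Int](repeating: 0, count: nBy2), count: nBy2)
  var a = empty, b = empty, c = empty, d = empty
  for i in 0..<nBy2 {
    for j in 0..<nBy2 {
      a[i][j] = m[i][j]                 // top-left
      b[i][j] = m[i][j + nBy2]          // top-right
      c[i][j] = m[i + nBy2][j]          // bottom-left
      d[i][j] = m[i + nBy2][j + nBy2]   // bottom-right
    }
  }
  return (a, b, c, d)
}

let grid: Grid = [
  [ 1,  2,  3,  4],
  [ 5,  6,  7,  8],
  [ 9, 10, 11, 12],
  [13, 14, 15, 16],
]
let q = quadrants(of: grid)
print(q.b)   // [[3, 4], [7, 8]]
print(q.c)   // [[9, 10], [13, 14]]
```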

Next, add the following to the bottom of the method:

let p1 = a.strassenRecursive(by: f-h)       // a * (f - h)
let p2 = (a+b).strassenRecursive(by: h)     // (a + b) * h
let p3 = (c+d).strassenRecursive(by: e)     // (c + d) * e
let p4 = d.strassenRecursive(by: g-e)       // d * (g - e)
let p5 = (a+d).strassenRecursive(by: e+h)   // (a + d) * (e + h)
let p6 = (b-d).strassenRecursive(by: g+h)   // (b - d) * (g + h)
let p7 = (a-c).strassenRecursive(by: e+f)   // (a - c) * (e + f)

Here, you recursively compute the seven matrix multiplications required by Strassen's algorithm. They are the exact same seven you saw in the section above!
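Why does trading a multiplication for extra additions pay off? Multiplying in 2×2 blocks the ordinary way takes eight recursive products; Strassen's formulas need only seven. Assuming each block addition or subtraction costs Θ(n²) for n×n inputs, the two recurrences work out as follows, the first for Strassen and the second for the eight-product version:

```latex
\begin{aligned}
T_{\text{Strassen}}(n) &= 7\,T(n/2) + \Theta(n^2) = \Theta\big(n^{\log_2 7}\big) \approx \Theta\big(n^{2.807}\big) \\
T_{\text{blocks}}(n)   &= 8\,T(n/2) + \Theta(n^2) = \Theta\big(n^{\log_2 8}\big) = \Theta\big(n^{3}\big)
\end{aligned}
```

Both results follow from the master theorem, because n^{log_2 7} and n^3 each grow faster than the Θ(n²) cost of the additions.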

Next, add the following:

let result11 = p5 + p4 - p2 + p6         // top-left:     a*e + b*g
let result12 = p1 + p2                   // top-right:    a*f + b*h
let result21 = p3 + p4                   // bottom-left:  c*e + d*g
let result22 = p1 + p5 - p3 - p7         // bottom-right: c*f + d*h

This combines the seven products into the four submatrices of the result matrix. Now for the final step! Add the following:
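To check the formulas by hand, shrink the problem to 2×2 matrices so that every "submatrix" a through h is a single number. This sketch (strassen2x2 and naive2x2 are hypothetical names) plugs plain Ints into the same p1 through p7 and result formulas and compares the outcome against the usual eight-multiplication definition:

```swift
/// Strassen's seven products for a 2×2 case where each "submatrix" is a
/// single Int, using the same p1...p7 and result formulas as the tutorial.
func strassen2x2(_ x: [[Int]], _ y: [[Int]]) -> [[Int]] {
  let (a, b, c, d) = (x[0][0], x[0][1], x[1][0], x[1][1])
  let (e, f, g, h) = (y[0][0], y[0][1], y[1][0], y[1][1])

  let p1 = a * (f - h)
  let p2 = (a + b) * h
  let p3 = (c + d) * e
  let p4 = d * (g - e)
  let p5 = (a + d) * (e + h)
  let p6 = (b - d) * (g + h)
  let p7 = (a - c) * (e + f)

  return [
    [p5 + p4 - p2 + p6, p1 + p2],
    [p3 + p4,           p1 + p5 - p3 - p7],
  ]
}

/// The usual definition, with eight multiplications, for comparison.
func naive2x2(_ x: [[Int]], _ y: [[Int]]) -> [[Int]] {
  let top = [x[0][0] * y[0][0] + x[0][1] * y[1][0], x[0][0] * y[0][1] + x[0][1] * y[1][1]]
  let bottom = [x[1][0] * y[0][0] + x[1][1] * y[1][0], x[1][0] * y[0][1] + x[1][1] * y[1][1]]
  return [top, bottom]
}

let x = [[1, 2], [3, 4]]
let y = [[5, 6], [7, 8]]
print(strassen2x2(x, y))   // [[19, 22], [43, 50]]
print(naive2x2(x, y))      // [[19, 22], [43, 50]]
```

With these inputs the seven products are p1 = -2, p2 = 24, p3 = 35, p4 = 8, p5 = 65, p6 = -30 and p7 = -22, which is a handy set of numbers to trace through the four result formulas.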

var result = Matrix(rows: n, columns: n)
for i in 0..<nBy2 {
  for j in 0..<nBy2 {
    result[i, j]           = result11[i, j]
    result[i, j+nBy2]      = result12[i, j]
    result[i+nBy2, j]      = result21[i, j]
    result[i+nBy2, j+nBy2] = result22[i, j]
  }
}

return result

Phew! Good work. In the final step, you combine the four submatrices into your result matrix. Notice that you only need to loop over 0..<nBy2 because each iteration of the loop fills four elements of the final result matrix, one in each quadrant. Yay for efficiency!
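With all the pieces in place, it can help to see the whole pipeline in one runnable file. The following is a compact sketch, not the tutorial's code: it assumes a plain [[Int]] grid instead of Matrix and uses hypothetical free functions (zeros, add, sub, strassen(_:_:) and strassenMultiply(_:_:)). It follows the same steps: pad to a power-of-two square, recurse with seven products, then crop back to the original row count × the other matrix's column count.

```swift
typealias Grid = [[Int]]

func zeros(rows: Int, cols: Int) -> Grid {
  Grid(repeating: [Int](repeating: 0, count: cols), count: rows)
}

// Element-wise addition and subtraction of equal-sized grids.
func add(_ x: Grid, _ y: Grid) -> Grid {
  zip(x, y).map { rowX, rowY in zip(rowX, rowY).map { $0 + $1 } }
}
func sub(_ x: Grid, _ y: Grid) -> Grid {
  zip(x, y).map { rowX, rowY in zip(rowX, rowY).map { $0 - $1 } }
}

/// Strassen recursion on n×n grids where n is a power of two.
func strassen(_ x: Grid, _ y: Grid) -> Grid {
  let n = x.count
  if n == 1 { return [[x[0][0] * y[0][0]]] }   // base case: 1×1
  let half = n / 2
  // The quadrant of `m` selected by row half `row` and column half `col` (each 0 or 1).
  func quad(_ m: Grid, _ row: Int, _ col: Int) -> Grid {
    (0..<half).map { i in Array(m[row * half + i][(col * half)..<(col * half + half)]) }
  }
  let (a, b, c, d) = (quad(x, 0, 0), quad(x, 0, 1), quad(x, 1, 0), quad(x, 1, 1))
  let (e, f, g, h) = (quad(y, 0, 0), quad(y, 0, 1), quad(y, 1, 0), quad(y, 1, 1))

  let p1 = strassen(a, sub(f, h))
  let p2 = strassen(add(a, b), h)
  let p3 = strassen(add(c, d), e)
  let p4 = strassen(d, sub(g, e))
  let p5 = strassen(add(a, d), add(e, h))
  let p6 = strassen(sub(b, d), add(g, h))
  let p7 = strassen(sub(a, c), add(e, f))

  let r11 = add(sub(add(p5, p4), p2), p6)
  let r12 = add(p1, p2)
  let r21 = add(p3, p4)
  let r22 = sub(sub(add(p1, p5), p3), p7)

  // Glue the four quadrants back together row by row.
  let top = zip(r11, r12).map { left, right in left + right }
  let bottom = zip(r21, r22).map { left, right in left + right }
  return top + bottom
}

/// Pads both operands with zeros to a power-of-two square, multiplies,
/// then crops the result back to rows(x) × columns(y).
func strassenMultiply(_ x: Grid, _ y: Grid) -> Grid {
  precondition(x[0].count == y.count, "x's column count must equal y's row count")
  let largest = max(x.count, x[0].count, y.count, y[0].count)
  var m = 1
  while m < largest { m *= 2 }
  var xp = zeros(rows: m, cols: m), yp = zeros(rows: m, cols: m)
  for (i, row) in x.enumerated() { for (j, v) in row.enumerated() { xp[i][j] = v } }
  for (i, row) in y.enumerated() { for (j, v) in row.enumerated() { yp[i][j] = v } }
  let full = strassen(xp, yp)
  return (0..<x.count).map { i in Array(full[i][0..<y[0].count]) }
}

let product = strassenMultiply([[1, 2], [3, 4], [5, 6]], [[7, 8, 9], [10, 11, 12]])
print(product)   // [[27, 30, 33], [61, 68, 75], [95, 106, 117]]
```

For these inputs it prints the same 3×3 product that a naive triple loop gives, even though the recursion internally works on 4×4 padded grids.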