Swift Algorithm Club: Strassen’s Algorithm

In this tutorial, you’ll learn how to implement Strassen’s Matrix Multiplication in Swift. This was the first matrix multiplication algorithm to beat the naive O(n³) implementation, and is a fantastic example of the Divide and Conquer coding paradigm — a favorite topic in coding interviews. By Richard Ash.


Example

To illustrate how the divide and conquer recursion works, look at the following example:

Above are two matrices, a 3x2 matrix A and a 2x4 matrix B. You'll use strassenMatrixMultiply(by:) to calculate the matrix multiplication. Following strassenMatrixMultiply(by:), you first prep the matrices A and B into two 4x4 matrices.
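As a rough illustration of the prep step, here's a minimal sketch that pads two matrices with zeros up to the next power-of-two square size. It uses plain nested arrays rather than the tutorial's Matrix type, and nextPowerOfTwo(_:), padded(_:toSize:) and the sample values are hypothetical names and data chosen for this example only.

```swift
// Hypothetical helpers for illustration; the tutorial's own prep code may differ.

/// Smallest power of two that is greater than or equal to `n` (for n >= 1).
func nextPowerOfTwo(_ n: Int) -> Int {
  var power = 1
  while power < n { power *= 2 }
  return power
}

/// Copies `matrix` into the top-left corner of a `size` x `size` grid of zeros.
func padded(_ matrix: [[Int]], toSize size: Int) -> [[Int]] {
  var result = Array(repeating: Array(repeating: 0, count: size), count: size)
  for (r, row) in matrix.enumerated() {
    for (c, value) in row.enumerated() {
      result[r][c] = value
    }
  }
  return result
}

// A 3x2 and a 2x4 matrix with made-up values.
let small = [[1, 2], [3, 4], [5, 6]]
let wide = [[1, 0, 2, 0], [0, 1, 0, 2]]

// The largest dimension is 4, which is already a power of two.
let paddedSize = nextPowerOfTwo(max(small.count, small[0].count, wide.count, wide[0].count))
let smallPrep = padded(small, toSize: paddedSize)
let widePrep = padded(wide, toSize: paddedSize)
```

The extra rows and columns hold only zeros, so they don't change the entries of the product that correspond to the original matrices.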

Once the matrices are prepped, you begin calling the recursive portion of the algorithm, strassenRecursive(by:). Above is the full recursion tree for strassenRecursive(by:) for this example. Each call to strassenRecursive(by:) on a matrix with more than one row generates 7 additional calls, each with a matrix half the size of the input matrix. This is the heart of Strassen's Algorithm, and where the divide and conquer strategy is used. You recursively split the matrices into halves, and then solve (conquer) each piece before combining the results.

Challenge

How many calls are there to strassenRecursive(by:) for this example?

[spoiler title="Solution"]

1 + 7 + 7² = 57. There are 3 layers of recursion (4 -> 2 -> 1). The first call, from strassenMatrixMultiply(by:), works on the 4x4 matrices and generates 7 calls on 2x2 matrices. Each of those generates 7 more calls on 1x1 matrices, for 49 in total. The 1x1 calls hit the base case and return without recursing any further.
[/spoiler]
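One way to sanity-check a call count like this is to mirror the shape of the recursion in a tiny function. The sketch below assumes the recursion stops as soon as a matrix has a single row and otherwise makes 7 calls on matrices of half the size. strassenCallCount(size:) is a hypothetical helper, not part of the tutorial's code.

```swift
/// Counts the calls made by a recursion that stops at 1x1 and otherwise
/// makes 7 calls on matrices of half the size.
/// `size` is assumed to be a power of two.
func strassenCallCount(size: Int) -> Int {
  // Base case: a 1x1 matrix is multiplied directly, so only this call counts.
  guard size > 1 else { return 1 }
  // This call, plus everything below each of its 7 recursive calls.
  return 1 + 7 * strassenCallCount(size: size / 2)
}

print(strassenCallCount(size: 4)) // 1 + 7 + 49 = 57
```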

Next, you'll review the first branch in detail.

Following the first call to strassenRecursive(by:), you first split the matrices APrep and BPrep into 8 submatrices.

You'll follow the first branch, a.strassenRecursive(by: f-h). This time self = a and other = f - h. You'll then split into the 8 submatrices again.

As before, follow the first branch, a.strassenRecursive(by: f-h). This time self = a' and other = f' - h'. Recall the line at the beginning of strassenRecursive(by:):

guard rowCount > 1 && other.rowCount > 1 else { return self * other }

Now, because the matrices only have one row/column, you just multiply the two elements together and return the result!

The result then propagates upwards and is used in the previous recursion.
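To see how the results of the seven branches combine, here's a minimal sketch for the smallest interesting case: two 2x2 matrices of plain integers. It assumes the labeling A = [[a, b], [c, d]] and B = [[e, f], [g, h]], so that the first product, a * (f - h), corresponds to the first branch followed above. strassen2x2(_:_:) is a hypothetical helper that works on nested arrays instead of the tutorial's Matrix type.

```swift
/// Multiplies two 2x2 integer matrices using Strassen's seven products.
/// Assumed layout: A = [[a, b], [c, d]] and B = [[e, f], [g, h]].
func strassen2x2(_ A: [[Int]], _ B: [[Int]]) -> [[Int]] {
  let (a, b, c, d) = (A[0][0], A[0][1], A[1][0], A[1][1])
  let (e, f, g, h) = (B[0][0], B[0][1], B[1][0], B[1][1])

  // Seven multiplications instead of the naive eight.
  let p1 = a * (f - h)
  let p2 = (a + b) * h
  let p3 = (c + d) * e
  let p4 = d * (g - e)
  let p5 = (a + d) * (e + h)
  let p6 = (b - d) * (g + h)
  let p7 = (a - c) * (e + f)

  // Combine the products into the four entries of the result.
  return [
    [p5 + p4 - p2 + p6, p1 + p2],
    [p3 + p4, p1 + p5 - p3 - p7]
  ]
}

let product = strassen2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(product) // [[19, 22], [43, 50]]
```

In the full algorithm, a through h are submatrices rather than numbers, and each of the seven multiplications becomes a recursive call.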

You could repeat this procedure for each and every recursion but that might take all day 😅. Good thing computers are much faster!

Time Complexity

As before, you can analyze the time complexity using the Master Theorem. The recurrence is T(n) = 7T(n/2) + O(n²), which leads to O(n^log₂(7)) complexity. This comes out to approximately O(n^2.8074), which is better than O(n³). 😁💯🙌
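To make the exponent a little more concrete, the following sketch computes log₂(7) and compares the number of scalar multiplications for n = 2, 4, 8 and 16. It assumes the recursion goes all the way down to 1x1 matrices, and it ignores the additions and subtractions, so treat it as a rough comparison rather than a benchmark.

```swift
import Foundation

// The exponent in O(n^log2(7)).
let exponent = log2(7.0)
print(exponent) // roughly 2.807

// For n = 2^k, the recursion performs 7^k scalar multiplications,
// while the naive algorithm performs n^3 = 8^k.
var strassenMultiplications = 1
var naiveMultiplications = 1
for k in 1...4 {
  strassenMultiplications *= 7
  naiveMultiplications *= 8
  print("n = \(1 << k): Strassen \(strassenMultiplications), naive \(naiveMultiplications)")
}
```

The gap between the two counts grows by a factor of 8/7 every time n doubles, which is why the benefit only shows up for large matrices.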

Trying It Out

Now that you've done all this work, try the method out! Add the following to the bottom of your playground:

let G = A.strassenMatrixMultiply(by: B)
printMatrix(G, name: "G")

If you remember from above, you actually ran this multiplication before, in matrix C. Check your answer to ensure the two agree. The output should look like:

Matrix G:
9 7
20 10

Now, one more. Add the following to the bottom of your playground:

let H = B.matrixMultiply(by: A)
printMatrix(H, name: "H")
let I = B.strassenMatrixMultiply(by: A)
printMatrix(I, name: "I")

Your output should look like:

Matrix H: 
6 43 -3 0
4 12 -2 0
-2 19 1 0
4 72 -2 0

Matrix I: 
6 43 -3 0
4 12 -2 0
-2 19 1 0
4 72 -2 0

Challenge

1. Initialize the following matrices J and K:

Matrix J:
1 2 3 8 -1
-1 18 2 0 1

Matrix K:
-1 2 98
3 4 4
0 1 2
9 6 5
3 1 -5

2. Compute L, the result of applying matrixMultiply to J and K.
3. Compute M, the result of applying strassenMatrixMultiply to J and K.
4. Print L and M to the console to check that they are equal.

[spoiler title="Solution"]

// 1
var J = Matrix<Int>(rows: 2, columns: 5)
J[row: 0] = [1, 2, 3, 8, -1]
J[row: 1] = [-1, 18, 2, 0, 1]

var K = Matrix<Int>(rows: 5, columns: 3)
K[column: 0] = [-1, 3, 0, 9, 3]
K[column: 1] = [2, 4, 1, 6, 1]
K[column: 2] = [98, 4, 2, 5, -5]

// 2
let L = J.matrixMultiply(by: K)
// 3
let M = J.strassenMatrixMultiply(by: K)
// 4
printMatrix(L, name: "L")
printMatrix(M, name: "M")

[/spoiler]
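If you'd like to check the values printed for L and M independently of the Matrix type, a plain triple-loop multiplication over nested arrays works as a cross-check. naiveMultiply(_:_:), jValues and kValues are hypothetical names used only in this sketch.

```swift
/// Textbook O(n³) matrix multiplication on nested arrays.
func naiveMultiply(_ lhs: [[Int]], _ rhs: [[Int]]) -> [[Int]] {
  precondition(!lhs.isEmpty && !rhs.isEmpty, "Matrices must not be empty")
  precondition(lhs[0].count == rhs.count, "Inner dimensions must match")

  let columns = rhs[0].count
  var result = Array(repeating: Array(repeating: 0, count: columns), count: lhs.count)
  for i in 0..<lhs.count {
    for j in 0..<columns {
      for k in 0..<rhs.count {
        result[i][j] += lhs[i][k] * rhs[k][j]
      }
    }
  }
  return result
}

// The same values as matrices J and K from the challenge.
let jValues = [[1, 2, 3, 8, -1], [-1, 18, 2, 0, 1]]
let kValues = [[-1, 2, 98], [3, 4, 4], [0, 1, 2], [9, 6, 5], [3, 1, -5]]

print(naiveMultiply(jValues, kValues)) // [[74, 60, 157], [58, 73, -27]]
```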

And that's the advantage of Strassen's algorithm!

Where to Go From Here?

You can get the completed playground using the Download Materials button at the top or bottom of the tutorial. It has all the code you've implemented above. You can also find the original implementation and further discussion in the Strassen's Matrix Multiplication section of the Swift Algorithm Club repository.

For more practice with divide and conquer algorithms, check out Karatsuba Multiplication or Merge Sort, both of which are implemented in the Swift Algorithm Club repository.

If you're interested in faster matrix multiplication algorithms, look at the Coppersmith–Winograd algorithm. It and its later refinements are among the asymptotically fastest known matrix multiplication algorithms, with a complexity of about O(n^2.37).

This was just one of the many algorithms in the Swift Algorithm Club repository. If you're interested in more, check out the repo.

It's in your best interest to know about algorithms and data structures — they're solutions to many real-world problems and are frequently asked as interview questions. Plus, they're fun!

Stay tuned for more tutorials from the Swift Algorithm Club in the future. In the meantime, if you have any questions on implementing Strassen's algorithm in Swift, please join the forum discussion below!

Note: The Swift Algorithm Club is always looking for more contributors. If you've got an interesting data structure, algorithm, or even an interview question to share, don't hesitate to contribute! To learn more about the contribution process, check out our Join the Swift Algorithm Club article.