1
0

Added matrix implementation

This commit is contained in:
neviyn 2020-07-30 21:39:21 +01:00
parent b703b7dd09
commit 194368e7a2

255
src/matrix.nim Normal file
View File

@ -0,0 +1,255 @@
import strformat
from "./tuple" import `=~`
type
Matrix* = object
matrix: seq[float]
width: int
height: int
MatrixInitialisationException = object of CatchableError
MatrixMultiplicationException = object of CatchableError
IdentityMatrixCreationError = object of CatchableError
MatrixInversionError = object of CatchableError
proc matrix(width, height: int): Matrix =
result.matrix = newSeq[float](width * height)
result.width = width
result.height = height
proc matrix*(data: seq[float], width, height: int): Matrix =
if len(data) != width * height:
raise newException(MatrixInitialisationException, &"Matrix dimensions did not match size of input data sequence, wanted {width * height} but it was {len(data)}")
result.matrix = data
result.width = width
result.height = height
template `[]`*(mat: Matrix, row, column: int): float = mat.matrix[row * mat.width + column]
template `[]=`*(mat: var Matrix, row, column: int, val: float) = mat.matrix[row * mat.width + column] = val
proc `==`*(lhs, rhs: Matrix): bool =
if lhs.width != rhs.width or lhs.height != rhs.height: return false
for i in 0..<len(lhs.matrix):
if not(lhs.matrix[i] =~ rhs.matrix[i]):
return false
return true
proc `*`*(lhs, rhs: Matrix): Matrix =
if lhs.width != rhs.height:
raise newException(MatrixMultiplicationException, "Cannot multiply matricies of these sizes")
result.width = rhs.width
result.height = lhs.height
result.matrix = newSeq[float](result.width * result.height)
for row in 0..<result.height:
for col in 0..<result.width:
var collect: float = 0.0
for i in 0..<lhs.width:
collect += lhs[row, i] * rhs[i, col]
result[row, col] = collect
template isSquare(mat: Matrix): bool = mat.height == mat.width and len(mat.matrix) == (mat.height * mat.width)
proc identity(size: int): Matrix =
result.width = size
result.height = size
result.matrix = newSeq[float](size * size)
for i in 0..<size:
result[i, i] = 1.0
proc identity(mat: Matrix): Matrix =
if not isSquare(mat):
raise newException(IdentityMatrixCreationError, "Can only create an identity for a square matrix")
result = identity(mat.height)
proc transpose(mat: Matrix): Matrix =
result.width = mat.height
result.height = mat.width
result.matrix = newSeq[float](result.width * result.height)
for j in 0..<mat.height:
for i in 0..<mat.width:
result[i, j] = mat[j, i]
proc cofactor(mat: Matrix, row, col: int): float
proc determinant(mat: Matrix): float =
if mat.width == 2 and mat.height == 2:
result = mat.matrix[0] * mat.matrix[3] - mat.matrix[1] * mat.matrix[2]
else:
for col in 0..<mat.width:
result += mat[0, col] * mat.cofactor(0, col)
template indexToCoordinate(mat: Matrix, index: int): (int, int) = (index div mat.width, index mod mat.width)
proc submatrix(mat: Matrix, row, col: int): Matrix =
result.width = mat.width - 1
result.height = mat.height - 1
for i in 0..<(mat.width * mat.height):
let coordinate = mat.indexToCoordinate(i)
if not (row == coordinate[0] or col == coordinate[1]):
result.matrix.add(mat.matrix[i])
proc minor(mat: Matrix, row, col: int): float = mat.submatrix(row, col).determinant
proc cofactor(mat: Matrix, row, col: int): float =
result = mat.minor(row, col)
if (row + col) mod 2 == 1:
result = result * -1
template isInvertable(mat: Matrix): bool = mat.determinant != 0
proc inverse(mat: Matrix): Matrix =
if not isSquare(mat):
raise newException(MatrixInversionError, "Can only invert a square matrix")
if not isInvertable(mat):
raise newException(MatrixInversionError, "Cannot invert this matrix, deteminant is zero")
result.width = mat.width
result.height = mat.height
result.matrix = newSeq[float](result.width * result.height)
for row in 0..<mat.width:
for col in 0..<mat.height:
result[col, row] = mat.cofactor(row, col) / mat.determinant
when isMainModule:
import unittest
suite "matrix":
test "Matrix creation 4x4":
let mat = matrix(@[1.0, 2.0, 3.0, 4.0, 5.5, 6.5, 7.5, 8.5, 9.0, 10.0, 11.0, 12.0, 13.5, 14.5, 15.5, 16.5], 4, 4)
check(mat[0,0] == 1.0)
check(mat[0,3] == 4.0)
check(mat[1,0] == 5.5)
check(mat[1,2] == 7.5)
check(mat[2,2] == 11.0)
check(mat[3,0] == 13.5)
check(mat[3,2] == 15.5)
test "Matrix creation 2x2":
let mat = matrix(@[-3.0, 5.0, 1.0, -2.0], 2, 2)
check(mat[0,0] == -3.0)
check(mat[0,1] == 5.0)
check(mat[1,0] == 1.0)
check(mat[1,1] == -2.0)
test "Matrix creation 3x3":
let mat = matrix(@[-3.0, 5.0, 0.0, 1.0, -2.0, -7.0, 0.0, 1.0, 1.0], 3, 3)
check(mat[0,0] == -3.0)
check(mat[1,1] == -2.0)
check(mat[2,2] == 1.0)
test "Matrix equality":
let mat1 = matrix(@[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0], 4, 4)
let mat2 = matrix(@[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0], 4, 4)
check(mat1 == mat2)
test "Matrix inequality":
let mat1 = matrix(@[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0], 4, 4)
let mat2 = matrix(@[2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0], 4, 4)
check(mat1 != mat2)
test "Matrix multiplication":
let mat1 = matrix(@[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0], 4, 4)
let mat2 = matrix(@[-2.0, 1.0, 2.0, 3.0, 3.0, 2.0, 1.0, -1.0, 4.0, 3.0, 6.0, 5.0, 1.0, 2.0, 7.0, 8.0], 4, 4)
let result = matrix(@[20.0, 22.0, 50.0, 48.0, 44.0, 54.0, 114.0, 108.0, 40.0, 58.0, 110.0, 102.0, 16.0, 26.0, 46.0, 42.0], 4, 4)
check(mat1 * mat2 == result)
test "Matrix multiplication, different dimensions":
let mat1 = matrix(@[1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 4.0, 2.0, 8.0, 6.0, 4.0, 1.0, 0.0, 0.0, 0.0, 1.0], 4, 4)
let mat2 = matrix(@[1.0, 2.0, 3.0, 1.0], 1, 4)
let result = matrix(@[18.0, 24.0, 33.0, 1.0], 1, 4)
check(mat1 * mat2 == result)
test "Matrix multiplication, different dimensions 2":
let mat1 = matrix(@[3.0, 4.0, 2.0], 3, 1)
let mat2 = matrix(@[13.0, 9.0, 7.0, 15.0, 8.0, 7.0, 4.0, 6.0, 6.0, 4.0, 0.0, 3.0], 4, 3)
let result = matrix(@[83.0, 63.0, 37.0, 75.0], 4, 1)
check(mat1 * mat2 == result)
test "Identity matrix creation":
check(identity(3) == matrix(@[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], 3, 3))
check(identity(4) == matrix(@[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0], 4, 4))
test "Identity matrix multiplication":
let mat1 = matrix(@[0.0, 1.0, 2.0, 4.0, 1.0, 2.0, 4.0, 8.0, 2.0, 4.0, 8.0, 16.0, 4.0, 8.0, 16.0, 32.0], 4, 4)
check(mat1 * mat1.identity == mat1)
let mat2 = matrix(@[1.0, 2.0, 3.0, 4.0], 4, 1)
check(mat2 * mat1.identity == mat2)
test "Matrix transposition":
let mat1 = matrix(@[0.0, 9.0, 3.0, 0.0, 9.0, 8.0, 0.0, 8.0, 1.0, 8.0, 5.0, 3.0, 0.0, 0.0, 5.0, 8.0], 4, 4)
let mat2 = matrix(@[0.0, 9.0, 1.0, 0.0, 9.0, 8.0, 8.0, 0.0, 3.0, 0.0, 5.0, 5.0, 0.0, 8.0, 3.0, 8.0], 4, 4)
check(mat1.transpose == mat2)
test "Identity matrix transposition":
let mat = identity(4)
check(mat.transpose == mat)
test "2x2 Matrix determinant":
check(matrix(@[1.0, 5.0, -3.0, 2.0], 2, 2).determinant =~ 17.0)
test "Matrix internal index to coordinate space":
let mat1 = matrix(@[0.0, 9.0, 3.0, 0.0, 9.0, 8.0, 0.0, 8.0, 1.0, 8.0, 5.0, 3.0, 0.0, 0.0, 5.0, 8.0], 4, 4)
let coord = mat1.indexToCoordinate(5)
check(mat1[1, 1] == mat1[coord[0], coord[1]])
test "Submatrix of a 3x3 matrix":
let mat1 = matrix(@[1.0, 5.0, 0.0, -3.0, 2.0, 7.0, 0.0, 6.0, -3.0], 3, 3)
let mat2 = matrix(@[-3.0, 2.0, 0.0, 6.0], 2, 2)
check(mat1.submatrix(0, 2) == mat2)
test "Submatrix of a 4x4 matrix":
let mat1 = matrix(@[-6.0, 1.0, 1.0, 6.0, -8.0, 5.0, 8.0, 6.0, -1.0, 0.0, 8.0, 2.0, -7.0, 1.0, -1.0, 1.0], 4, 4)
let mat2 = matrix(@[-6.0, 1.0, 6.0, -8.0, 8.0, 6.0, -7.0, -1.0, 1.0], 3, 3)
check(mat1.submatrix(2, 1) == mat2)
test "Minor of a 3x3 matrix":
let mat1 = matrix(@[3.0, 5.0, 0.0, 2.0, -1.0, -7.0, 6.0, -1.0, 5.0], 3, 3)
let submat = mat1.submatrix(1, 0)
check(submat.determinant == 25)
check(submat.determinant == mat1.minor(1, 0))
test "Cofactor":
let mat1 = matrix(@[3.0, 5.0, 0.0, 2.0, -1.0, -7.0, 6.0, -1.0, 5.0], 3, 3)
check(mat1.minor(0,0) =~ -12.0)
check(mat1.cofactor(0,0) =~ -12.0)
check(mat1.minor(1,0) =~ 25.0)
check(mat1.cofactor(1,0) =~ -25.0)
test "3x3 matrix determinant":
let mat1 = matrix(@[1.0, 2.0, 6.0, -5.0, 8.0, -4.0, 2.0, 6.0, 4.0], 3, 3)
check(mat1.cofactor(0,0) =~ 56.0)
check(mat1.cofactor(0,1) =~ 12.0)
check(mat1.cofactor(0,2) =~ -46.0)
check(mat1.determinant =~ -196)
test "4x4 matrix determinant":
let mat1 = matrix(@[-2.0, -8.0, 3.0, 5.0, -3.0, 1.0, 7.0, 3.0, 1.0, 2.0, -9.0, 6.0, -6.0, 7.0, 7.0, -9.0], 4, 4)
check(mat1.cofactor(0,0) =~ 690.0)
check(mat1.cofactor(0,1) =~ 447.0)
check(mat1.cofactor(0,2) =~ 210.0)
check(mat1.cofactor(0,3) =~ 51.0)
check(mat1.determinant =~ -4071.0)
test "Matrix invertibility":
check(matrix(@[6.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 6.0, 4.0, -9.0, 3.0, -7.0, 9.0, 1.0, 7.0, -6.0], 4, 4).isInvertable)
check(not matrix(@[-4.0, 2.0, -2.0, -3.0, 9.0, 6.0, 2.0, 6.0, 0.0, -5.0, 1.0, -5.0, 0.0, 0.0, 0.0, 0.0], 4, 4).isInvertable)
test "Invert matrix":
let mat1 = matrix(@[-5.0, 2.0, 6.0, -8.0, 1.0, -5.0, 1.0, 8.0, 7.0, 7.0, -6.0, -7.0, 1.0, -3.0, 7.0, 4.0], 4, 4)
let mat1inverse = mat1.inverse
check(mat1.determinant =~ 532.0)
check(mat1.cofactor(2, 3) =~ -160.0)
check(mat1inverse[3,2] =~ -160.0/532.0)
check(mat1.cofactor(3, 2) =~ 105.0)
check(mat1inverse[2,3] =~ 105.0/532.0)
check(mat1inverse == matrix(@[0.2180451127819549, 0.4511278195488722, 0.2406015037593985, -0.04511278195488722, -0.8082706766917294, -1.456766917293233, -0.443609022556391, 0.5206766917293233, -0.07894736842105263, -0.2236842105263158, -0.05263157894736842, 0.1973684210526316, -0.5225563909774437, -0.8139097744360902, -0.3007518796992481, 0.306390977443609], 4, 4))
test "Invert matrix 2":
let mat1 = matrix(@[8.0, -5.0, 9.0, 2.0, 7.0, 5.0, 6.0, 1.0, -6.0, 0.0, 9.0, 6.0, -3.0, 0.0, -9.0, -4.0], 4, 4)
check(mat1.inverse == matrix(@[-0.1538461538461539, -0.1538461538461539, -0.282051282051282, -0.5384615384615384, -0.07692307692307693, 0.1230769230769231, 0.02564102564102564, 0.03076923076923077, 0.358974358974359, 0.358974358974359, 0.4358974358974359, 0.9230769230769231, -0.6923076923076923, -0.6923076923076923, -0.7692307692307693, -1.923076923076923], 4, 4))
test "Invert matrix 3":
let mat1 = matrix(@[9.0, 3.0, 0.0, 9.0, -5.0, -2.0, -6.0, -3.0, -4.0, 9.0, 6.0, 4.0, -7.0, 6.0, 6.0, 2.0], 4, 4)
check(mat1.inverse == matrix(@[-0.04074074074074074, -0.07777777777777778, 0.1444444444444444, -0.2222222222222222, -0.07777777777777778, 0.03333333333333333, 0.3666666666666666, -0.3333333333333333, -0.02901234567901234, -0.1462962962962963, -0.1092592592592593, 0.1296296296296296, 0.1777777777777778, 0.06666666666666667, -0.2666666666666667, 0.3333333333333333], 4, 4))
test "Multiply product by inverse":
let mat1 = matrix(@[3.0, -9.0, 7.0, 3.0, 3.0, -8.0, 2.0, -9.0, -4.0, 4.0, 4.0, 1.0, -6.0, 5.0, -1.0, 1.0], 4, 4)
let mat2 = matrix(@[8.0, 2.0, 2.0, 2.0, 3.0, -1.0, 7.0, 0.0, 7.0, 0.0, 5.0, 4.0, 6.0, -2.0, 0.0, 5.0], 4, 4)
let mat3 = mat1 * mat2
check(mat3 * mat2.inverse == mat1)