// Demo entry point: builds a 2x3 matrix and a 3x1 matrix, multiplies them
// with Multiply', and prints the two operands followed by the product.
method Main()
{
  // Left operand, 2 rows by 3 columns:
  //   1 2 3
  //   4 5 6
  var lhs: array2<int> := new int[2,3];
  lhs[0,0] := 1;
  lhs[0,1] := 2;
  lhs[0,2] := 3;
  lhs[1,0] := 4;
  lhs[1,1] := 5;
  lhs[1,2] := 6;

  // Right operand, 3 rows by 1 column: 7, 8, 9.
  var rhs: array2<int> := new int[3,1];
  rhs[0,0] := 7;
  rhs[1,0] := 8;
  rhs[2,0] := 9;

  // The column count of lhs (3) equals the row count of rhs (3), which is
  // what Multiply' requires through AllowedToMultiply.
  var product: array2<int> := Multiply'(lhs, rhs);

  PrintMatrix(lhs);
  print "\n*\n";
  PrintMatrix(rhs);
  print "\n=\n";
  PrintMatrix(product);
}
// Prints the elements of `m` row by row. Each row is preceded by a newline,
// every element is followed by a tab, and one final newline is printed after
// the last row.
method PrintMatrix(m: array2<int>)
  requires m != null
{
  var r: nat := 0;
  while r < m.Length0
    decreases m.Length0 - r
  {
    // Start a fresh output line for this row.
    print "\n";
    var c: nat := 0;
    while c < m.Length1
      decreases m.Length1 - c
    {
      print m[r,c], "\t";
      c := c + 1;
    }
    r := r + 1;
  }
  print "\n";
}
// Holds when both arrays are non-null and the number of columns of m1
// equals the number of rows of m2, i.e. the dimensions permit m1 * m2.
predicate AllowedToMultiply(m1: array2<int>, m2: array2<int>)
{
  // The null checks come first so the length accesses below are well-formed.
  m1 != null &&
  m2 != null &&
  m2.Length0 == m1.Length1
}
// Holds when m1 and m2 may be multiplied (see AllowedToMultiply) and m3 is a
// non-null array with as many rows as m1 and as many columns as m2, i.e. m3
// has the dimensions of the product m1 * m2.
predicate AllowedToMultiplyInto(m1: array2<int>, m2: array2<int>, m3: array2<int>)
{
  AllowedToMultiply(m1, m2) &&
  m3 != null &&
  m1.Length0 == m3.Length0 &&
  m2.Length1 == m3.Length1
}
// m3 is the result of multiplying the matrix m1 by the matrix m2: the
// dimensions fit (AllowedToMultiplyInto) and every element m3[i,j] equals
// RowColumnProduct(m1, m2, i, j).
predicate MM(m1: array2<int>, m2: array2<int>, m3: array2<int>)
  // The body reads elements of m3 directly and elements of m1 and m2 through
  // RowColumnProduct, so all three arrays must be in the reads frame.
  reads m1, m2, m3
{
  AllowedToMultiplyInto(m1, m2, m3) &&
  forall i, j :: 0 <= i < m3.Length0 && 0 <= j < m3.Length1 ==>
    m3[i,j] == RowColumnProduct(m1, m2, i, j)
}
// The product of row `row` of m1 with column `column` of m2, defined as the
// sum RowColumnProductFrom accumulates starting at index 0.
function RowColumnProduct(m1: array2<int>, m2: array2<int>, row: nat, column: nat): int
  requires AllowedToMultiply(m1, m2)
  requires row < m1.Length0 && column < m2.Length1
  // RowColumnProductFrom reads elements of m1 and m2, so the caller's frame
  // has to cover both arrays.
  reads m1, m2
{
  RowColumnProductFrom(m1, m2, row, column, 0)
}
// Sum of m1[row,t] * m2[t,column] for t from k up to (but excluding)
// m1.Length1. Yields 0 when k == m1.Length1.
function RowColumnProductFrom(m1: array2<int>, m2: array2<int>, row: nat, column: nat, k: nat): int
  requires AllowedToMultiply(m1, m2)
  requires row < m1.Length0 && column < m2.Length1
  requires k <= m1.Length1
  // The body reads elements of both arrays.
  reads m1, m2
  decreases m1.Length1 - k
{
  if k == m1.Length1 then
    0
  else
    m1[row,k] * m2[k,column] + RowColumnProductFrom(m1, m2, row, column, k + 1)
}
// Sum of m1[row,t] * m2[t,column] for t from k up to (but excluding) i.
// Yields 0 when k == i.
function RowColumnProductTo(m1: array2<int>, m2: array2<int>, row: nat, column: nat, k: nat, i: nat): int
  requires AllowedToMultiply(m1, m2)
  // `i` is an exclusive upper bound, so it may be as large as m1.Length1
  // (which AllowedToMultiply makes equal to m2.Length0). Every index actually
  // read satisfies k < i and is therefore in range for both arrays.
  requires row < m1.Length0 && column < m2.Length1 && i <= m1.Length1
  requires k <= i
  // The body reads elements of both arrays.
  reads m1, m2
  decreases i - k
{
  if k == i then
    0
  else
    m1[row,k] * m2[k,column] + RowColumnProductTo(m1, m2, row, column, k + 1, i)
}
// Sum of the first n terms m1[row,t] * m2[t,column], for t in 0 .. n-1.
// Yields 0 when n == 0. The recursion peels off the LAST term, which matches
// how the loop in Multiply' accumulates the sum one term at a time.
function RowColumnProductForCount(m1: array2<int>, m2: array2<int>, row: nat, column: nat, n: nat): int
  requires AllowedToMultiply(m1, m2)
  requires row < m1.Length0 && column < m2.Length1 && n <= m1.Length1
  // The body reads elements of both arrays.
  reads m1, m2
  decreases n
{
  if n == 0 then
    0
  else
    RowColumnProductForCount(m1, m2, row, column, n - 1) + m1[row,n-1] * m2[n-1,column]
}
// The first `rown` rows of m3 are complete: every element m3[r,c] with
// r < rown equals the full-length sum RowColumnProductForCount for (r, c).
predicate MMROW(m1: array2<int>, m2: array2<int>, m3: array2<int>, rown: nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires rown <= m1.Length0
  // Reads elements of m3 directly and of m1, m2 via RowColumnProductForCount.
  reads m1, m2, m3
{
  forall r: nat, c: nat :: r < rown && c < m2.Length1 ==>
    m3[r,c] == RowColumnProductForCount(m1, m2, r, c, m1.Length1)
}
// The first `coln` columns of row `row` of m3 are complete: every element
// m3[row,c] with c < coln equals the full-length sum
// RowColumnProductForCount for (row, c).
predicate MMCOL(m1: array2<int>, m2: array2<int>, m3: array2<int>, row: nat, coln: nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires row < m1.Length0 && coln <= m2.Length1
  // Reads elements of m3 directly and of m1, m2 via RowColumnProductForCount.
  reads m1, m2, m3
{
  forall c: nat :: c < coln ==>
    m3[row,c] == RowColumnProductForCount(m1, m2, row, c, m1.Length1)
}
// The element m3[row,col] currently holds the partial sum of the first n
// terms of the row-by-column product (RowColumnProductForCount with count n).
predicate MMI(m1: array2<int>, m2: array2<int>, m3: array2<int>, row: nat, col: nat, n: nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires row < m1.Length0 && col < m2.Length1 && n <= m1.Length1
  // Reads one element of m3 directly and elements of m1, m2 via
  // RowColumnProductForCount.
  reads m1, m2, m3
{
  m3[row,col] == RowColumnProductForCount(m1, m2, row, col, n)
}
// Multiplies m1 by m2 and returns the product in a newly allocated array.
//
// Requires: the dimensions fit (AllowedToMultiply).
// Ensures:  MM(m1, m2, m3), i.e. m3 has the product's dimensions and every
//           element equals RowColumnProduct for its row and column.
//
// Proof outline: three nested loops track progress with
//   MMROW(row)          - all rows before `row` are finished,
//   MMCOL(row, col)     - in the current row, all columns before `col` are
//                         finished,
//   MMI(row, col, i)    - the current cell holds the sum of its first i terms.
// Each inner loop repeats the facts of the enclosing loops, because a loop
// that may modify m3 forgets everything about m3 that its invariants do not
// state. At the end MMROWImpliesMM converts MMROW over all rows into MM.
method Multiply'(m1: array2<int>, m2: array2<int>) returns (m3: array2<int>)
  requires AllowedToMultiply(m1, m2)
  ensures MM(m1, m2, m3)
{
  m3 := new int[m1.Length0, m2.Length1];
  var row: nat := 0;
  // loop over rows of m1
  while row < m1.Length0
    invariant row <= m1.Length0
    invariant MMROW(m1, m2, m3, row)
    modifies m3
  {
    // loop over columns of m2
    var col: nat := 0;
    while col < m2.Length1
      invariant col <= m2.Length1
      invariant MMROW(m1, m2, m3, row)
      invariant MMCOL(m1, m2, m3, row, col)
    {
      // loop over elements of m1 row / m2 column
      var i: nat := 0;
      m3[row,col] := 0;
      while i < m1.Length1
        invariant i <= m1.Length1
        invariant MMROW(m1, m2, m3, row)
        invariant MMCOL(m1, m2, m3, row, col)
        // Exactly the partial sum of the first i terms. Requiring this for
        // every j <= i at once would force all partial sums to be equal.
        invariant MMI(m1, m2, m3, row, col, i)
      {
        m3[row,col] := m3[row,col] + (m1[row,i] * m2[i,col]);
        i := i + 1;
        assert MMI(m1, m2, m3, row, col, i);
      }
      // The loop exited with i == m1.Length1, so the cell holds the full sum.
      assert MMI(m1, m2, m3, row, col, m1.Length1);
      assert m3[row,col] == RowColumnProductForCount(m1, m2, row, col, m1.Length1);
      col := col + 1;
      assert MMCOL(m1, m2, m3, row, col);
    }
    // Every column of the current row is finished, so the row is finished.
    assert MMCOL(m1, m2, m3, row, m2.Length1);
    row := row + 1;
    assert MMROW(m1, m2, m3, row);
  }
  assert MMROW(m1, m2, m3, m1.Length0);
  MMROWImpliesMM(m1, m2, m3);
}
// If every row of m3 is complete in the sense of MMROW (stated with
// RowColumnProductForCount), then m3 is the matrix product in the sense of MM
// (stated with RowColumnProduct).
lemma MMROWImpliesMM(m1: array2<int>, m2: array2<int>, m3: array2<int>)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires MMROW(m1, m2, m3, m1.Length0)
  ensures MM(m1, m2, m3)
{
  // Step 1: spell out what MMROW over all rows says about each cell.
  assert forall i: nat, j: nat :: i < m1.Length0 && j < m2.Length1 ==>
    m3[i,j] == RowColumnProductForCount(m1, m2, i, j, m1.Length1);

  // Step 2: convert the count-based sum to RowColumnProduct, cell by cell.
  forall i: nat, j: nat | i < m1.Length0 && j < m2.Length1
    ensures m3[i,j] == RowColumnProduct(m1, m2, i, j)
  {
    assert m3[i,j] == RowColumnProductForCount(m1, m2, i, j, m1.Length1);
    RowColumnProductForCountImpliesRowColumnProduct(m1, m2, m3, i, j);
  }

  // Step 3: restate the result over the dimensions of m3, as MM phrases it.
  assert forall i: nat, j: nat :: i < m3.Length0 && j < m3.Length1 ==>
    m3[i,j] == RowColumnProduct(m1, m2, i, j);
}
// For a single cell (r, c): if m3[r,c] equals the sum built from the front
// (RowColumnProductForCount over all m1.Length1 terms), then it also equals
// RowColumnProduct, which is built from the back by RowColumnProductFrom.
//
// Proof idea: for every split point n, the first n terms (ForCount) plus the
// remaining terms (From) make up the same total. Start with the split at
// m1.Length1, where the From part is empty, and move it down to 0, where the
// ForCount part is empty.
lemma RowColumnProductForCountImpliesRowColumnProduct(m1: array2<int>, m2: array2<int>, m3: array2<int>, r: nat, c: nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires r < m1.Length0 && c < m2.Length1
  requires m3[r,c] == RowColumnProductForCount(m1, m2, r, c, m1.Length1)
  ensures m3[r,c] == RowColumnProduct(m1, m2, r, c)
{
  assert RowColumnProduct(m1, m2, r, c) == RowColumnProductFrom(m1, m2, r, c, 0);
  var total := RowColumnProductForCount(m1, m2, r, c, m1.Length1);
  var n: nat := m1.Length1;
  while n > 0
    invariant n <= m1.Length1
    invariant total == RowColumnProductForCount(m1, m2, r, c, n) + RowColumnProductFrom(m1, m2, r, c, n)
    decreases n
  {
    // Moving the split from n to n-1 shifts the term
    // m1[r,n-1] * m2[n-1,c] from the ForCount part to the From part.
    n := n - 1;
  }
  // Here n == 0, so the ForCount part is 0 and total equals the From part at 0.
  assert total == RowColumnProductFrom(m1, m2, r, c, 0);
}