From 8cd001568a4e3f7b82f1d86764269fbecfb82ef6 Mon Sep 17 00:00:00 2001 From: Anton Blanchard Date: Sat, 8 Feb 2020 14:36:52 +1100 Subject: [PATCH 1/2] Fix signed multiply The upper bits of signed multiplications was all wrong. Fix it. Signed-off-by: Anton Blanchard --- src/main/scala/SimpleMultiplier.scala | 39 ++++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/main/scala/SimpleMultiplier.scala b/src/main/scala/SimpleMultiplier.scala index 1485696..3af7250 100644 --- a/src/main/scala/SimpleMultiplier.scala +++ b/src/main/scala/SimpleMultiplier.scala @@ -17,7 +17,7 @@ class SimpleMultiplier(val bits: Int) extends Module { val out = Output(Valid(UInt(bits.W))) }) - val a = Reg(UInt((2*bits).W)) + val a = Reg(UInt(bits.W)) val b = Reg(UInt((2*bits).W)) val is32bit = Reg(Bool()) val high = Reg(Bool()) @@ -28,22 +28,35 @@ class SimpleMultiplier(val bits: Int) extends Module { io.in.ready := !busy when (io.in.valid && !busy) { + val aSignExtend = WireDefault(io.in.bits.a) + val bSignExtend = WireDefault(io.in.bits.b) + + /* Sign or zero extend 32 bit values */ when (io.in.bits.is32bit) { when (io.in.bits.signed) { - a := io.in.bits.a.signExtend(32, 2*bits) - b := io.in.bits.b.signExtend(32, 2*bits) + aSignExtend := io.in.bits.a.signExtend(32, bits) + bSignExtend := io.in.bits.b.signExtend(32, bits) } .otherwise { - a := io.in.bits.a(31, 0) - b := io.in.bits.b(31, 0) + aSignExtend := io.in.bits.a.zeroExtend(32, bits) + bSignExtend := io.in.bits.b.zeroExtend(32, bits) + } + } + + when (io.in.bits.signed) { + /* + * We always want a positive value in a, so take the two's complement + * of both args if a is negative + */ + when (aSignExtend(bits-1)) { + a := -aSignExtend + b := -(bSignExtend.signExtend(bits, 2*bits)) + } .otherwise { + a := aSignExtend + b := bSignExtend.signExtend(bits, 2*bits) } } .otherwise { - when (io.in.bits.signed) { - a := io.in.bits.a.signExtend(64, 2*bits) - b := io.in.bits.b.signExtend(64, 2*bits) - } .otherwise { - a := io.in.bits.a - b := io.in.bits.b - } + a := aSignExtend + b := bSignExtend } is32bit := io.in.bits.is32bit @@ -62,7 +75,7 @@ class SimpleMultiplier(val bits: Int) extends Module { count := count + 1.U } - val result = WireDefault(res) + val result = WireDefault(res(63, 0)) when (high) { when (is32bit) { result := res(63, 32) ## res(63, 32) From dc8b74d51d52ebff79dfb68f15482f76e1652600 Mon Sep 17 00:00:00 2001 From: Anton Blanchard Date: Sat, 8 Feb 2020 23:01:44 +1100 Subject: [PATCH 2/2] Improve Multiplier tests Signed-off-by: Anton Blanchard --- src/test/scala/SimpleMultiplier.scala | 105 ++++++++++++++------------ 1 file changed, 56 insertions(+), 49 deletions(-) diff --git a/src/test/scala/SimpleMultiplier.scala b/src/test/scala/SimpleMultiplier.scala index 4fac7e8..d9b695c 100644 --- a/src/test/scala/SimpleMultiplier.scala +++ b/src/test/scala/SimpleMultiplier.scala @@ -24,66 +24,73 @@ class SimpleMultiplierUnitTester extends FlatSpec with ChiselScalatestTester wit multHigh(aSigned, bSigned) } + def mult32(a: BigInt, b: BigInt): BigInt = { + val a32 = a & BigInt("ffffffff", 16) + val b32 = b & BigInt("ffffffff", 16) + + (a32 * b32) + } + + def multHigh32(a: BigInt, b: BigInt): BigInt = { + val a32 = a & BigInt("ffffffff", 16) + val b32 = b & BigInt("ffffffff", 16) + + val m = ((a32 * b32) >> 32) & BigInt("ffffffff", 16) + (m << 32) | m + } + + def multHighSigned32(a: BigInt, b: BigInt): BigInt = { + val a32 = a & BigInt("ffffffff", 16) + val b32 = b & BigInt("ffffffff", 16) + val aSigned = if (a32.testBit(31)) (BigInt("ffffffff", 16) << 32) + a32 else a32 + val bSigned = if (b32.testBit(31)) (BigInt("ffffffff", 16) << 32) + b32 else b32 + + val m = ((aSigned * bSigned) >> 32) & BigInt("ffffffff", 16) + (m << 32) | m + } + + def runOneTest(m: SimpleMultiplier, mult: (BigInt, BigInt) => BigInt) = { + for ((x, y) <- tests) { + while (m.io.in.ready.peek().litToBoolean == false) { + m.clock.step(1) + } + + m.io.in.bits.a.poke(x.U) + m.io.in.bits.b.poke(y.U) + m.io.in.valid.poke(true.B) + m.clock.step(1) + m.io.in.valid.poke(false.B) + + while (m.io.out.valid.peek().litToBoolean == false) { + m.clock.step(1) + } + + m.io.out.bits.expect(mult(x, y).U) + } + } + it should "pass a unit test" in { test(new SimpleMultiplier(64)) { m => - for ((x, y) <- tests) { - while (m.io.in.ready.peek().litToBoolean == false) { - m.clock.step(1) - } - m.io.in.bits.a.poke(x.U) - m.io.in.bits.b.poke(y.U) - m.io.in.valid.poke(true.B) - m.clock.step(1) - m.io.in.valid.poke(false.B) - - while (m.io.out.valid.peek().litToBoolean == false) { - m.clock.step(1) - } - - m.io.out.bits.expect(mult(x, y).U) - } + runOneTest(m, mult) m.io.in.bits.high.poke(true.B) - - for ((x, y) <- tests) { - while (m.io.in.ready.peek().litToBoolean == false) { - m.clock.step(1) - } - - m.io.in.bits.a.poke(x.U) - m.io.in.bits.b.poke(y.U) - m.io.in.valid.poke(true.B) - m.clock.step(1) - m.io.in.valid.poke(false.B) - - while (m.io.out.valid.peek().litToBoolean == false) { - m.clock.step(1) - } - - m.io.out.bits.expect(multHigh(x, y).U) - } + runOneTest(m, multHigh) m.io.in.bits.signed.poke(true.B) + runOneTest(m, multHighSigned) - for ((x, y) <- tests) { - while (m.io.in.ready.peek().litToBoolean == false) { - m.clock.step(1) - } + m.io.in.bits.signed.poke(false.B) + m.io.in.bits.high.poke(false.B) - m.io.in.bits.a.poke(x.U) - m.io.in.bits.b.poke(y.U) - m.io.in.valid.poke(true.B) - m.clock.step(1) - m.io.in.valid.poke(false.B) + m.io.in.bits.is32bit.poke(true.B) + runOneTest(m, mult32) - while (m.io.out.valid.peek().litToBoolean == false) { - m.clock.step(1) - } - - m.io.out.bits.expect(multHighSigned(x, y).U) - } + m.io.in.bits.high.poke(true.B) + runOneTest(m, multHigh32) + m.io.in.bits.signed.poke(true.B) + runOneTest(m, multHighSigned32) } } }