From 8cd001568a4e3f7b82f1d86764269fbecfb82ef6 Mon Sep 17 00:00:00 2001 From: Anton Blanchard Date: Sat, 8 Feb 2020 14:36:52 +1100 Subject: [PATCH] Fix signed multiply The upper bits of signed multiplications was all wrong. Fix it. Signed-off-by: Anton Blanchard --- src/main/scala/SimpleMultiplier.scala | 39 ++++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/main/scala/SimpleMultiplier.scala b/src/main/scala/SimpleMultiplier.scala index 1485696..3af7250 100644 --- a/src/main/scala/SimpleMultiplier.scala +++ b/src/main/scala/SimpleMultiplier.scala @@ -17,7 +17,7 @@ class SimpleMultiplier(val bits: Int) extends Module { val out = Output(Valid(UInt(bits.W))) }) - val a = Reg(UInt((2*bits).W)) + val a = Reg(UInt(bits.W)) val b = Reg(UInt((2*bits).W)) val is32bit = Reg(Bool()) val high = Reg(Bool()) @@ -28,22 +28,35 @@ class SimpleMultiplier(val bits: Int) extends Module { io.in.ready := !busy when (io.in.valid && !busy) { + val aSignExtend = WireDefault(io.in.bits.a) + val bSignExtend = WireDefault(io.in.bits.b) + + /* Sign or zero extend 32 bit values */ when (io.in.bits.is32bit) { when (io.in.bits.signed) { - a := io.in.bits.a.signExtend(32, 2*bits) - b := io.in.bits.b.signExtend(32, 2*bits) + aSignExtend := io.in.bits.a.signExtend(32, bits) + bSignExtend := io.in.bits.b.signExtend(32, bits) } .otherwise { - a := io.in.bits.a(31, 0) - b := io.in.bits.b(31, 0) + aSignExtend := io.in.bits.a.zeroExtend(32, bits) + bSignExtend := io.in.bits.b.zeroExtend(32, bits) + } + } + + when (io.in.bits.signed) { + /* + * We always want a positive value in a, so take the two's complement + * of both args if a is negative + */ + when (aSignExtend(bits-1)) { + a := -aSignExtend + b := -(bSignExtend.signExtend(bits, 2*bits)) + } .otherwise { + a := aSignExtend + b := bSignExtend.signExtend(bits, 2*bits) } } .otherwise { - when (io.in.bits.signed) { - a := io.in.bits.a.signExtend(64, 2*bits) - b := io.in.bits.b.signExtend(64, 2*bits) - } .otherwise { - a := io.in.bits.a - b := io.in.bits.b - } + a := aSignExtend + b := bSignExtend } is32bit := io.in.bits.is32bit @@ -62,7 +75,7 @@ class SimpleMultiplier(val bits: Int) extends Module { count := count + 1.U } - val result = WireDefault(res) + val result = WireDefault(res(63, 0)) when (high) { when (is32bit) { result := res(63, 32) ## res(63, 32)