Branch Prediction Unit (BPU)

“I don’t make mistakes. I make prophecies which immediately turn out to be wrong.”

Introduction

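This article introduces the BPU, the top-level branch prediction module of XiangShan. The BPU supplies predicted fetch targets to the Fetch Target Queue (FTQ) and mainly contains the following logic: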

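  • Override prediction logic (a later stage's prediction overrides an earlier stage's)
  • Pipeline handshaking (between stages, and between the BPU and the FTQ)
  • Global branch history management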

Basic Structure and Parameters

BPU

The branch prediction unit uses a multi-stage hybrid prediction architecture with two main parts:

  • Next Line Predictor (NLP): μBTB

  • Accurate Predictor (APD): FTB/TAGE-SC/ITTAGE/RAS
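The override idea can be illustrated with a toy model in plain Scala. It is not Chisel and not XiangShan code, and the names are hypothetical. It assumes each stage produces one target for the same fetch block, with the fast NLP answering first and the APD components answering in later stages. It also assumes a later stage only overrides (redirects) when its target differs from the one already sent.

```scala
// Toy model of overriding prediction, not XiangShan code. preds(i) is the
// target that stage s(i+1) predicts for one fetch block; None means the stage
// gave no result for it.
object OverrideModel {
  // Returns the final target and the (1-based) stages that overrode an
  // earlier, different prediction. Each override would cost a redirect.
  def resolve(preds: Seq[Option[Long]]): (Option[Long], Seq[Int]) =
    preds.zipWithIndex.foldLeft((Option.empty[Long], Seq.empty[Int])) {
      case ((current, overrides), (Some(t), i)) =>
        val disagrees = current.exists(_ != t)
        (Some(t), if (disagrees) overrides :+ (i + 1) else overrides)
      case (acc, (None, _)) => acc
    }
}
```

A later stage that agrees with the earlier one costs nothing. Only a disagreement produces an override.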

Input/Output

BasePredictor

IO.in

The basic inputs of a predictor are as follows:

class BasePredictorInput (implicit p: Parameters) extends XSBundle with HasBPUConst {
  def nInputs = 1

  val s0_pc = UInt(VAddrBits.W)

  val folded_hist = new AllFoldedHistories(foldedGHistInfos) // folded global history
  val ghist = UInt(HistoryLength.W) // global history

  val resp_in = Vec(nInputs, new BranchPredictionResp)
}
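The folded_hist input above compresses the long global history into a short value that can index a table. The sketch below is a hypothetical plain-Scala model of the general folding technique, not XiangShan's AllFoldedHistories. Bit i of the history is XORed into folded bit (i mod compLen). An incremental update then does three things:

- rotates the folded value,
- XORs in the new bit,
- XORs out the oldest bit that leaves the window.

That last step is presumably why the BPU also carries the oldest history bits alongside the folded history.

```scala
// Hypothetical model of history folding, not XiangShan's AllFoldedHistories.
// hist(i) is the i-th most recent outcome (index 0 = newest).
object FoldedHistModel {
  // Direct fold: XOR history bit i into folded bit (i % compLen).
  def fold(hist: Seq[Boolean], compLen: Int): Int =
    hist.zipWithIndex.foldLeft(0) { case (acc, (b, i)) =>
      if (b) acc ^ (1 << (i % compLen)) else acc
    }

  // Rotate a compLen-bit value left by one position.
  private def rotl1(x: Int, compLen: Int): Int = {
    val mask = (1 << compLen) - 1
    ((x << 1) | (x >>> (compLen - 1))) & mask
  }

  // Incremental update when newBit is shifted in and oldestBit, the bit at
  // index histLen - 1, falls out of the history window.
  def update(folded: Int, newBit: Boolean, oldestBit: Boolean,
             histLen: Int, compLen: Int): Int = {
    var f = rotl1(folded, compLen)       // every old bit moves one position up
    if (newBit) f ^= 1                   // newest bit lands at position 0
    if (oldestBit) f ^= 1 << (histLen % compLen) // cancel the dropped bit
    f
  }
}
```

Updating incrementally should always match refolding the shifted history from scratch. The check below tests this for every folded width from 1 to 11.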

IO.out

class BasePredictorOutput (implicit p: Parameters) extends XSBundle with HasBPUConst {
  val last_stage_meta = UInt(MaxMetaLength.W) // This is used by the composer
  val resp = new BranchPredictionResp // predictor result, containing the prediction of every stage
}

IO

class BasePredictorIO (implicit p: Parameters) extends XSBundle with HasBPUConst {
  val reset_vector = Input(UInt(PAddrBits.W))
  val in = Flipped(DecoupledIO(new BasePredictorInput)) // BasePredictorInput
  val out = Output(new BasePredictorOutput) // BasePredictorOutput

  val ctrl = Input(new BPUCtrl) // control signals

  val s0_fire = Input(Bool()) // fire: advances this stage into the next one
  val s1_fire = Input(Bool())
  val s2_fire = Input(Bool())
  val s3_fire = Input(Bool())

  val s2_redirect = Input(Bool()) // redirect: this stage overrides an earlier prediction
  val s3_redirect = Input(Bool())

  val s1_ready = Output(Bool())
  val s2_ready = Output(Bool())
  val s3_ready = Output(Bool())

  val update = Flipped(Valid(new BranchPredictionUpdate))
  val redirect = Flipped(Valid(new BranchPredictionRedirect))
}

PredictorIO

class PredictorIO(implicit p: Parameters) extends XSBundle {
  val bpu_to_ftq = new BpuToFtqIO() // prediction results
  val ftq_to_bpu = Flipped(new FtqToBpuIO()) // Flipped reverses input/output; carries update and redirect info

  val ctrl = Input(new BPUCtrl)
  val reset_vector = Input(UInt(PAddrBits.W))
}

Predictor Logic

Handshake Logic

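Every BPU pipeline stage is connected to the FTQ. The valid bit of the BPU-to-FTQ handshake is raised in two cases:

  • the first prediction stage holds a valid result that the next stage can accept;
  • a later prediction stage produces a result that differs from the earlier one.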

// handshake logic
io.bpu_to_ftq.resp.valid :=
  s1_valid && s2_components_ready && s2_ready ||
  s2_fire && s2_redirect || // different prediction result
  s3_fire && s3_redirect
io.bpu_to_ftq.resp.bits := BpuToFtqBundle(predictors.io.out.resp)
io.bpu_to_ftq.resp.bits.meta := predictors.io.out.last_stage_meta // TODO: change to lastStageMeta
io.bpu_to_ftq.resp.bits.s3.folded_hist := s3_folded_gh
io.bpu_to_ftq.resp.bits.s3.histPtr := s3_ghist_ptr
io.bpu_to_ftq.resp.bits.s3.lastBrNumOH := s3_last_br_num_oh
io.bpu_to_ftq.resp.bits.s3.afhob := s3_ahead_fh_oldest_bits
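The valid condition above is the OR of three independent cases. This Boolean model restates it with explicit parentheses. FtqHandshakeModel and respValid are hypothetical helpers, not part of XiangShan. In both Scala and Chisel, && binds tighter than ||, so this grouping matches the original expression.

```scala
// Hypothetical restatement of io.bpu_to_ftq.resp.valid; not XiangShan code.
object FtqHandshakeModel {
  def respValid(s1Valid: Boolean, s2ComponentsReady: Boolean, s2Ready: Boolean,
                s2Fire: Boolean, s2Redirect: Boolean,
                s3Fire: Boolean, s3Redirect: Boolean): Boolean =
    (s1Valid && s2ComponentsReady && s2Ready) || // a new s1 prediction can move on
    (s2Fire && s2Redirect) ||                    // s2 overrides an earlier result
    (s3Fire && s3Redirect)                       // s3 overrides an earlier result
}
```

For example, a valid s1 result is not offered to the FTQ while the s2 predictor components are not ready. A firing s2 that does not redirect also sends nothing new.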

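Redirect

When the predictions of different stages disagree, a redirect is needed. For example, s3 compares its result with the prediction previously produced by s2 (previous_s2_pred):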

val s3_redirect_on_br_taken = resp.s3.full_pred.real_br_taken_mask().asUInt =/= previous_s2_pred.full_pred.real_br_taken_mask().asUInt
val s3_redirect_on_target = resp.s3.getTarget =/= previous_s2_pred.getTarget
val s3_redirect_on_jalr_target = resp.s3.full_pred.hit_taken_on_jalr && resp.s3.full_pred.jalr_target =/= previous_s2_pred.full_pred.jalr_target
val s3_redirect_on_fall_thru_error = resp.s3.fallThruError

s3_redirect := s3_fire && (
  s3_redirect_on_br_taken || s3_redirect_on_target || s3_redirect_on_fall_thru_error
)

def preds_needs_redirect_vec(x: BranchPredictionBundle, y: BranchPredictionBundle) = {
  VecInit(
    x.getTarget =/= y.getTarget,
    x.lastBrPosOH.asUInt =/= y.lastBrPosOH.asUInt,
    x.taken =/= y.taken,
    (x.taken && y.taken) && x.cfiIndex.bits =/= y.cfiIndex.bits,
  )
}
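preds_needs_redirect_vec compares two predictions field by field. The plain-Scala sketch below mirrors its four conditions in the same order. It has two assumptions of its own:

- It works on a simplified stand-in for BranchPredictionBundle. The Pred case class and its field types are hypothetical.
- A redirect is needed when any one condition holds.

```scala
// Hypothetical stand-in for BranchPredictionBundle, reduced to the fields
// the comparison reads; the field types are simplifications.
final case class Pred(target: Long, lastBrPosOH: Int, taken: Boolean, cfiIndex: Int)

object RedirectModel {
  // Same four conditions, in the same order, as preds_needs_redirect_vec.
  def needsRedirectVec(x: Pred, y: Pred): Seq[Boolean] = Seq(
    x.target != y.target,                            // different target
    x.lastBrPosOH != y.lastBrPosOH,                  // different last-branch position
    x.taken != y.taken,                              // taken vs. not taken
    (x.taken && y.taken) && x.cfiIndex != y.cfiIndex // both taken, at different slots
  )

  // Assumption: a redirect is needed if any condition holds.
  def needsRedirect(x: Pred, y: Pred): Boolean = needsRedirectVec(x, y).contains(true)
}
```

Note that the fourth condition only matters when both predictions are taken. If only one of them is taken, the third condition already fires.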

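Pipeline Logic

Registers between the prediction stages hold intermediate results. When a stage fires, its PC, branch history, and related information move into the next stage: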

val reset_vector = DelayN(io.reset_vector, 5)
val s0_pc = Wire(UInt(VAddrBits.W))
val s0_pc_reg = RegNext(s0_pc)
when (RegNext(RegNext(reset.asBool) && !reset.asBool)) {
  s0_pc_reg := reset_vector
}
val s1_pc = RegEnable(s0_pc, s0_fire)
val s2_pc = RegEnable(s1_pc, s1_fire)
val s3_pc = RegEnable(s2_pc, s2_fire)

// folded history
val s0_folded_gh = Wire(new AllFoldedHistories(foldedGHistInfos))
val s0_folded_gh_reg = RegNext(s0_folded_gh, 0.U.asTypeOf(s0_folded_gh))
val s1_folded_gh = RegEnable(s0_folded_gh, 0.U.asTypeOf(s0_folded_gh), s0_fire)
val s2_folded_gh = RegEnable(s1_folded_gh, 0.U.asTypeOf(s0_folded_gh), s1_fire)
val s3_folded_gh = RegEnable(s2_folded_gh, 0.U.asTypeOf(s0_folded_gh), s2_fire)

// one-hot lastBrPosOH of the previous prediction (numBr+1 bits); input to the folded-history update
val s0_last_br_num_oh = Wire(UInt((numBr+1).W))
val s0_last_br_num_oh_reg = RegNext(s0_last_br_num_oh, 0.U)
val s1_last_br_num_oh = RegEnable(s0_last_br_num_oh, 0.U, s0_fire)
val s2_last_br_num_oh = RegEnable(s1_last_br_num_oh, 0.U, s1_fire)
val s3_last_br_num_oh = RegEnable(s2_last_br_num_oh, 0.U, s2_fire)

// oldest history bits of the ahead folded histories; input to the folded-history update
val s0_ahead_fh_oldest_bits = Wire(new AllAheadFoldedHistoryOldestBits(foldedGHistInfos))
val s0_ahead_fh_oldest_bits_reg = RegNext(s0_ahead_fh_oldest_bits, 0.U.asTypeOf(s0_ahead_fh_oldest_bits))
val s1_ahead_fh_oldest_bits = RegEnable(s0_ahead_fh_oldest_bits, 0.U.asTypeOf(s0_ahead_fh_oldest_bits), s0_fire)
val s2_ahead_fh_oldest_bits = RegEnable(s1_ahead_fh_oldest_bits, 0.U.asTypeOf(s0_ahead_fh_oldest_bits), s1_fire)
val s3_ahead_fh_oldest_bits = RegEnable(s2_ahead_fh_oldest_bits, 0.U.asTypeOf(s0_ahead_fh_oldest_bits), s2_fire)
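RegEnable(next, en) keeps its old value unless en is high at the clock edge. This hypothetical cycle-level model of the PC chain above shows that all stage registers update at the same edge. Each register samples the value its predecessor held before that edge, so a block advances one stage only when the corresponding fire signal is high. PcPipe and step are illustrative names.

```scala
// Hypothetical cycle-level model of s1_pc, s2_pc, s3_pc, each built as
// RegEnable(previous stage, previous stage's fire).
final case class PcPipe(s1: Long, s2: Long, s3: Long) {
  // One clock edge: each register samples its predecessor's pre-edge value
  // when its enable is high, and holds otherwise.
  def step(s0Pc: Long, s0Fire: Boolean, s1Fire: Boolean, s2Fire: Boolean): PcPipe =
    PcPipe(
      if (s0Fire) s0Pc else s1, // s1_pc = RegEnable(s0_pc, s0_fire)
      if (s1Fire) s1 else s2,   // s2_pc = RegEnable(s1_pc, s1_fire)
      if (s2Fire) s2 else s3    // s3_pc = RegEnable(s2_pc, s2_fire)
    )
}
```

In the last test case, s0 does not fire. The block in s1 stays put while a copy of it advances into s2, so s1 and s2 briefly hold the same PC.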

When a redirect occurs, the pipeline must be flushed:

// pipeline flush
s3_flush := redirect_req.valid // flush when redirect comes
s2_flush := s3_flush || s3_redirect
s1_flush := s2_flush || s2_redirect

s1_components_ready := predictors.io.s1_ready
s1_ready := s1_fire || !s1_valid
s0_fire := !reset.asBool && s1_components_ready && s1_ready // s0 can fire when s1 is ready
predictors.io.s0_fire := s0_fire

s2_components_ready := predictors.io.s2_ready
s2_ready := s2_fire || !s2_valid
s1_fire := s1_valid && s2_components_ready && s2_ready && io.bpu_to_ftq.resp.ready

s3_components_ready := predictors.io.s3_ready
s3_ready := s3_fire || !s3_valid
s2_fire := s2_valid && s3_components_ready && s3_ready

when (redirect_req.valid) { s1_valid := false.B }
  .elsewhen(s0_fire) { s1_valid := true.B }
  .elsewhen(s1_flush) { s1_valid := false.B }
  .elsewhen(s1_fire) { s1_valid := false.B }

predictors.io.s1_fire := s1_fire

s2_fire := s2_valid // Chisel last-connect semantics: this overrides the earlier s2_fire assignment

when(s2_flush) { s2_valid := false.B }
  .elsewhen(s1_fire) { s2_valid := !s1_flush }
  .elsewhen(s2_fire) { s2_valid := false.B }

predictors.io.s2_fire := s2_fire
predictors.io.s2_redirect := s2_redirect

s3_fire := s3_valid

when(s3_flush) { s3_valid := false.B }
  .elsewhen(s2_fire) { s3_valid := !s2_flush }
  .elsewhen(s3_fire) { s3_valid := false.B }

predictors.io.s3_fire := s3_fire
predictors.io.s3_redirect := s3_redirect
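The flush and valid-bit logic above can be checked with a small cycle-level model. This plain-Scala sketch takes the fire, redirect, and redirect_req signals as inputs instead of deriving them. It mirrors the when/elsewhen priority order of each sN_valid update, and Valids, FlushModel, and step are hypothetical names.

```scala
// Hypothetical model of the sN_valid registers after one clock edge.
final case class Valids(s1: Boolean, s2: Boolean, s3: Boolean)

object FlushModel {
  def step(v: Valids, redirectReq: Boolean,
           s0Fire: Boolean, s1Fire: Boolean, s2Fire: Boolean, s3Fire: Boolean,
           s2Redirect: Boolean, s3Redirect: Boolean): Valids = {
    // flush signals, combinational in the same cycle
    val s3Flush = redirectReq
    val s2Flush = s3Flush || s3Redirect
    val s1Flush = s2Flush || s2Redirect

    // each if/else chain follows the when/elsewhen order of the excerpt
    val s1 =
      if (redirectReq) false
      else if (s0Fire) true
      else if (s1Flush) false
      else if (s1Fire) false
      else v.s1
    val s2 =
      if (s2Flush) false
      else if (s1Fire) !s1Flush
      else if (s2Fire) false
      else v.s2
    val s3 =
      if (s3Flush) false
      else if (s2Fire) !s2Flush
      else if (s3Fire) false
      else v.s3
    Valids(s1, s2, s3)
  }
}
```

The first test shows an s2 override with three effects:

  • the block in s1, now on the wrong path, is dropped instead of moving into s2;
  • the s2 block moves on to s3;
  • s1 is refilled from s0.

An external redirect_req clears every stage.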

Branch History Management

// History manage
// s1
val s1_possible_predicted_ghist_ptrs = (0 to numBr).map(s1_ghist_ptr - _.U)
val s1_predicted_ghist_ptr = Mux1H(resp.s1.lastBrPosOH, s1_possible_predicted_ghist_ptrs)

val s1_possible_predicted_fhs = (0 to numBr).map(i =>
  s1_folded_gh.update(s1_ahead_fh_oldest_bits, s1_last_br_num_oh, i, resp.s1.brTaken && resp.s1.lastBrPosOH(i)))
val s1_predicted_fh = Mux1H(resp.s1.lastBrPosOH, s1_possible_predicted_fhs)

val s1_ahead_fh_ob_src = Wire(new AllAheadFoldedHistoryOldestBits(foldedGHistInfos))
s1_ahead_fh_ob_src.read(ghv, s1_ghist_ptr)

if (EnableGHistDiff) {
  val s1_predicted_ghist = WireInit(getHist(s1_predicted_ghist_ptr).asTypeOf(Vec(HistoryLength, Bool())))
  for (i <- 0 until numBr) {
    when (resp.s1.shouldShiftVec(i)) {
      s1_predicted_ghist(i) := resp.s1.brTaken && (i==0).B
    }
  }
  when (s1_valid) {
    s0_ghist := s1_predicted_ghist.asUInt
  }
}

val s1_ghv_wens = (0 until HistoryLength).map(n =>
  (0 until numBr).map(b => (s1_ghist_ptr).value === (CGHPtr(false.B, n.U) + b.U).value && resp.s1.shouldShiftVec(b) && s1_valid))
val s1_ghv_wdatas = (0 until HistoryLength).map(n =>
  Mux1H(
    (0 until numBr).map(b => (
      (s1_ghist_ptr).value === (CGHPtr(false.B, n.U) + b.U).value && resp.s1.shouldShiftVec(b),
      resp.s1.brTaken && resp.s1.lastBrPosOH(b+1)
    ))
  )
)
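The code above treats the global history as a circular buffer. In the excerpt, bit i of lastBrPosOH selects s1_ghist_ptr - i. Branch b of the block writes slot s1_ghist_ptr - b, with value brTaken && lastBrPosOH(b+1), so only the last shifted branch can record "taken". The sketch below models this in plain Scala. HistPtr is a hypothetical stand-in for CGHPtr. The sketch also assumes shouldShiftVec(b) is true exactly for the first k branches shifted in.

```scala
// Hypothetical stand-in for CGHPtr: a wrap flag plus an index into a
// circular history buffer with `size` entries.
final case class HistPtr(flag: Boolean, value: Int, size: Int) {
  // Move backwards by n slots; the flag flips when the index wraps.
  def -(n: Int): HistPtr = {
    require(n >= 0 && n <= size, "shift out of range")
    val raw = value - n
    if (raw < 0) HistPtr(!flag, raw + size, size) else HistPtr(flag, raw, size)
  }
}

object HistoryModel {
  // Mirrors Mux1H(lastBrPosOH, (0 to numBr).map(ptr - _)):
  // bit i of the one-hot lastBrPosOH selects ptr - i.
  def predictedPtr(ptr: HistPtr, lastBrPosOH: Int, numBr: Int): HistPtr = {
    val candidates = (0 to numBr).map(i => ptr - i)
    val selected = (0 to numBr).filter(i => ((lastBrPosOH >> i) & 1) == 1)
    require(selected.size == 1, "lastBrPosOH must be one-hot")
    candidates(selected.head)
  }

  // Writes the k history bits shifted in by one block. Branch b goes to
  // slot ptr - b; only the last shifted branch (b == k - 1) can be taken.
  // Assumes shouldShiftVec(b) is true exactly for b < k.
  def writeBits(ghv: Vector[Boolean], ptr: HistPtr, k: Int, brTaken: Boolean): Vector[Boolean] =
    (0 until k).foldLeft(ghv) { (g, b) =>
      g.updated((ptr - b).value, brTaken && b == k - 1)
    }
}
```

The wrap flag lets two pointers with the same index be told apart, which is the usual reason circular-queue pointers carry one.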

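Storing Stage Results

Each stage registers its results as candidate inputs of the corresponding priority mux generator:

  • npcGen for the next PC;
  • foldedGhGen for the folded history;
  • ghistPtrGen for the history pointer;
  • and so on for the other generators below.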

val npcGen   = new PhyPriorityMuxGenerator[UInt]
val foldedGhGen = new PhyPriorityMuxGenerator[AllFoldedHistories]
val ghistPtrGen = new PhyPriorityMuxGenerator[CGHPtr]
val lastBrNumOHGen = new PhyPriorityMuxGenerator[UInt]
val aheadFhObGen = new PhyPriorityMuxGenerator[AllAheadFoldedHistoryOldestBits]

npcGen.register(s1_valid, resp.s1.getTarget, Some("s1_target"), 4)
foldedGhGen.register(s1_valid, s1_predicted_fh, Some("s1_FGH"), 4)
ghistPtrGen.register(s1_valid, s1_predicted_ghist_ptr, Some("s1_GHPtr"), 4)
lastBrNumOHGen.register(s1_valid, resp.s1.lastBrPosOH.asUInt, Some("s1_BrNumOH"), 4)
aheadFhObGen.register(s1_valid, s1_ahead_fh_ob_src, Some("s1_AFHOB"), 4)
ghvBitWriteGens.zip(s1_ghv_wens).zipWithIndex.map{case ((b, w), i) =>
  b.register(w.reduce(_||_), s1_ghv_wdatas(i), Some(s"s1_new_bit_$i"), 4)
}

npcGen.register(s2_redirect, resp.s2.getTarget, Some("s2_target"), 5)
foldedGhGen.register(s2_redirect, s2_predicted_fh, Some("s2_FGH"), 5)
ghistPtrGen.register(s2_redirect, s2_predicted_ghist_ptr, Some("s2_GHPtr"), 5)
lastBrNumOHGen.register(s2_redirect, resp.s2.lastBrPosOH.asUInt, Some("s2_BrNumOH"), 5)
aheadFhObGen.register(s2_redirect, s2_ahead_fh_ob_src, Some("s2_AFHOB"), 5)
ghvBitWriteGens.zip(s2_ghv_wens).zipWithIndex.map{case ((b, w), i) =>
  b.register(w.reduce(_||_), s2_ghv_wdatas(i), Some(s"s2_new_bit_$i"), 5)
}

// open question: why is the priority argument for s3 (3) lower than for s2 (5)?
npcGen.register(s3_redirect, resp.s3.getTarget, Some("s3_target"), 3)
foldedGhGen.register(s3_redirect, s3_predicted_fh, Some("s3_FGH"), 3)
ghistPtrGen.register(s3_redirect, s3_predicted_ghist_ptr, Some("s3_GHPtr"), 3)
lastBrNumOHGen.register(s3_redirect, resp.s3.lastBrPosOH.asUInt, Some("s3_BrNumOH"), 3)
aheadFhObGen.register(s3_redirect, s3_ahead_fh_ob_src, Some("s3_AFHOB"), 3)
ghvBitWriteGens.zip(s3_ghv_wens).zipWithIndex.map{case ((b, w), i) =>
  b.register(w.reduce(_||_), s3_ghv_wdatas(i), Some(s"s3_new_bit_$i"), 3)
}
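Each register call adds one candidate (enable, value) to a generator. As an illustration only, the sketch below models a generic priority mux in plain Scala. Sources are listed from highest to lowest priority, and the first enabled one wins, which is the same rule as chisel3.util.PriorityMux. A default is returned when no source is enabled. The sketch does not model the trailing Int argument of register, whose exact effect on selection the excerpt does not explain. The s3 > s2 > s1 ordering in the test is likewise an assumption.

```scala
// Hypothetical model of a priority mux; not PhyPriorityMuxGenerator itself.
final case class MuxSource[T](enable: Boolean, data: T, name: String)

object PriorityMuxModel {
  // Sources are ordered from highest to lowest priority; the first enabled
  // source wins, otherwise the default (value, name) pair is returned.
  def select[T](sources: Seq[MuxSource[T]], default: (T, String)): (T, String) =
    sources.find(_.enable).map(s => (s.data, s.name)).getOrElse(default)
}
```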

Reset

when (RegNext(RegNext(reset.asBool) && !reset.asBool)) {    // reset release
  s1_pc := reset_vector
}
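RegNext(RegNext(reset.asBool) && !reset.asBool) is an edge detector. It is high for one cycle, two cycles after the last cycle in which reset was asserted. The hypothetical model below replays a reset waveform cycle by cycle and assumes both registers start at 0.

```scala
// Cycle-level model of RegNext(RegNext(reset) && !reset).
// Both registers are assumed to start at 0 (false).
object ResetReleaseModel {
  // reset(t) is the reset level in cycle t; returns the cycles in which the
  // release pulse is high.
  def pulses(reset: Seq[Boolean]): Seq[Int] = {
    var delayedReset = false // RegNext(reset)
    var pulse = false        // RegNext(delayedReset && !reset)
    val out = Vector.newBuilder[Int]
    for ((rst, cycle) <- reset.zipWithIndex) {
      if (pulse) out += cycle
      // compute next-cycle register values from current-cycle values
      val nextPulse = delayedReset && !rst
      delayedReset = rst
      pulse = nextPulse
    }
    out.result()
  }
}
```

With reset asserted in cycles 0 to 2, the pulse appears in cycle 4. That is the single cycle in which the reset vector is loaded.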
