Skip to content

Commit

Permalink
Fixed CSR Instruction Handlers and Use New Register Macros (#65)
Browse files Browse the repository at this point in the history
Set of fixes needed to get the `rv64mi-p-csr` arch test to pass.
- Use new register macros in all instruction handlers
- Set some CSR read-only fields at boot based on the configured XLEN
- Added a way to override CSR register values through a JSON file
- Made several fixes to the CSR instruction handlers
- Improved Instruction Logger to better support exceptions and to print
CSR values for CSR instructions
  • Loading branch information
kathlenemagnus authored Feb 12, 2025
1 parent d8ca794 commit 2c5d2a9
Show file tree
Hide file tree
Showing 25 changed files with 804 additions and 640 deletions.
2 changes: 2 additions & 0 deletions arch/default_csr_values.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{
}
18 changes: 16 additions & 2 deletions arch/register_macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@
const auto mask = ((1ULL << (high_bit - low_bit + 1)) - 1) << low_bit; \
csr_value &= ~mask; \
\
uint64_t new_field_value = field_value; \
new_field_value <<= low_bit; \
const uint64_t new_field_value = field_value << low_bit; \
csr_value |= new_field_value; \
\
state->getCsrRegister(reg_ident)->write(csr_value); \
Expand All @@ -72,3 +71,18 @@
#define PEEK_CSR_REG(reg_ident) READ_CSR_REG(reg_ident)

#define POKE_CSR_REG(reg_ident, reg_value) state->getCsrRegister(reg_ident)->dmiWrite(reg_value);

#define POKE_CSR_FIELD(reg_ident, field_name, field_value) \
{ \
auto csr_value = state->getCsrRegister(reg_ident)->dmiRead<uint64_t>(); \
\
const auto low_bit = atlas::getCsrBitRange(reg_ident, #field_name).first; \
const auto high_bit = atlas::getCsrBitRange(reg_ident, #field_name).second; \
const auto mask = ((1ULL << (high_bit - low_bit + 1)) - 1) << low_bit; \
csr_value &= ~mask; \
\
const uint64_t new_field_value = (uint64_t)field_value << low_bit; \
csr_value |= new_field_value; \
\
state->getCsrRegister(reg_ident)->dmiWrite(csr_value); \
}
62 changes: 38 additions & 24 deletions core/AtlasInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,44 @@

namespace atlas
{

template <mavis::InstMetaData::OperandFieldID OperandFieldId>
sparta::Register* getSpartaReg(AtlasState* state,
const mavis::OperandInfo::ElementList & operand_list)
const mavis::OperandInfo::Element*
getOperand(const mavis::OperandInfo::ElementList & operand_list)
{
const auto operand = std::find_if(operand_list.begin(), operand_list.end(),
[](const mavis::OperandInfo::Element & operand)
{ return operand.field_id == OperandFieldId; });
const auto operand_it = std::find_if(operand_list.begin(), operand_list.end(),
[](const mavis::OperandInfo::Element & operand)
{ return operand.field_id == OperandFieldId; });

// Instruction does not have this operand type
if (operand == operand_list.end())
if (operand_it == operand_list.end())
{
return nullptr;
}
else
{
return &(*operand_it);
}
}

switch (operand->operand_type)
sparta::Register* getSpartaReg(AtlasState* state, const mavis::OperandInfo::Element* operand)
{
if (operand)
{
case mavis::InstMetaData::OperandTypes::WORD:
case mavis::InstMetaData::OperandTypes::LONG:
return state->getIntRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::SINGLE:
case mavis::InstMetaData::OperandTypes::DOUBLE:
case mavis::InstMetaData::OperandTypes::QUAD:
return state->getFpRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::VECTOR:
return state->getVecRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::NONE:
sparta_assert(false, "Invalid Mavis Operand Type!");
switch (operand->operand_type)
{
case mavis::InstMetaData::OperandTypes::WORD:
case mavis::InstMetaData::OperandTypes::LONG:
return state->getIntRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::SINGLE:
case mavis::InstMetaData::OperandTypes::DOUBLE:
case mavis::InstMetaData::OperandTypes::QUAD:
return state->getFpRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::VECTOR:
return state->getVecRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::NONE:
sparta_assert(false, "Invalid Mavis Operand Type!");
}
}

return nullptr;
Expand All @@ -40,12 +51,15 @@ namespace atlas
opcode_info_(opcode_info),
extractor_info_(extractor_info),
opcode_size_(((getOpcode() & 0x3) != 0x3) ? 2 : 4),
rs1_(getSpartaReg<mavis::InstMetaData::OperandFieldID::RS1>(
state, opcode_info->getSourceOpInfoList())),
rs2_(getSpartaReg<mavis::InstMetaData::OperandFieldID::RS2>(
state, opcode_info->getSourceOpInfoList())),
rd_(getSpartaReg<mavis::InstMetaData::OperandFieldID::RD>(
state, opcode_info->getDestOpInfoList())),
rs1_info_(getOperand<mavis::InstMetaData::OperandFieldID::RS1>(
opcode_info->getSourceOpInfoList())),
rs2_info_(getOperand<mavis::InstMetaData::OperandFieldID::RS2>(
opcode_info->getSourceOpInfoList())),
rd_info_(
getOperand<mavis::InstMetaData::OperandFieldID::RD>(opcode_info->getDestOpInfoList())),
rs1_reg_(getSpartaReg(state, rs1_info_)),
rs2_reg_(getSpartaReg(state, rs2_info_)),
rd_reg_(getSpartaReg(state, rd_info_)),
inst_action_group_(extractor_info_->inst_action_group_)
{
}
Expand Down
64 changes: 48 additions & 16 deletions core/AtlasInst.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,21 @@ namespace atlas

uint64_t getImmediate() const
{
sparta_assert(opcode_info_->hasImmediate(), "Failed to get immediate value!");
sparta_assert(hasImmediate(), "Failed to get immediate value!");
return opcode_info_->getImmediate();
}

bool hasCsr() const
{
return opcode_info_->isInstType(mavis::OpcodeInfo::InstructionTypes::CSR);
}

uint32_t getCsr() const
{
sparta_assert(hasCsr(), "Failed to get CSR!");
return opcode_info_->getSpecialField(mavis::OpcodeInfo::SpecialField::CSR);
}

template <class T, uint32_t imm_size> T getSignExtendedImmediate() const
{
sparta_assert(opcode_info_->hasImmediate(), "Failed to get immediate value!");
Expand All @@ -56,29 +67,47 @@ namespace atlas

uint32_t getOpcodeSize() const { return opcode_size_; }

sparta::Register* getRs1()
uint32_t getRs1() const
{
sparta_assert(rs1_info_, "Operand RS1 is a nullptr! " << *this);
return rs1_info_->field_value;
}

uint32_t getRs2() const
{
sparta_assert(rs2_info_, "Operand RS2 is a nullptr! " << *this);
return rs2_info_->field_value;
}

uint32_t getRd() const
{
sparta_assert(rd_info_, "Operand RD is a nullptr! " << *this);
return rd_info_->field_value;
}

sparta::Register* getRs1Reg() const
{
sparta_assert(rs1_, "Operand RS1 is a nullptr! " << *this);
return rs1_;
sparta_assert(rs1_reg_, "Operand RS1 is a nullptr! " << *this);
return rs1_reg_;
}

sparta::Register* getRs2()
sparta::Register* getRs2Reg() const
{
sparta_assert(rs2_, "Operand RS2 is a nullptr! " << *this);
return rs2_;
sparta_assert(rs2_reg_, "Operand RS2 is a nullptr! " << *this);
return rs2_reg_;
}

sparta::Register* getRd()
sparta::Register* getRdReg() const
{
sparta_assert(rd_, "Operand RD is a nullptr! " << *this);
return rd_;
sparta_assert(rd_reg_, "Operand RD is a nullptr! " << *this);
return rd_reg_;
}

bool hasRs1() const { return rs1_ != nullptr; }
bool hasRs1() const { return rs1_reg_ != nullptr; }

bool hasRs2() const { return rs2_ != nullptr; }
bool hasRs2() const { return rs2_reg_ != nullptr; }

bool hasRd() const { return rd_ != nullptr; }
bool hasRd() const { return rd_reg_ != nullptr; }

ActionGroup* getActionGroup() { return &inst_action_group_; }

Expand All @@ -102,9 +131,12 @@ namespace atlas
Addr next_pc_;

// Registers
sparta::Register* rs1_;
sparta::Register* rs2_;
sparta::Register* rd_;
const mavis::OperandInfo::Element* rs1_info_;
const mavis::OperandInfo::Element* rs2_info_;
const mavis::OperandInfo::Element* rd_info_;
sparta::Register* rs1_reg_;
sparta::Register* rs2_reg_;
sparta::Register* rd_reg_;

ActionGroup inst_action_group_;
bool unimplemented_ = false;
Expand Down
89 changes: 66 additions & 23 deletions core/AtlasState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ namespace atlas
supported_isa_string_(std::string("rv" + std::to_string(xlen_) + "g_zicsr_zifencei")),
isa_file_path_(p->isa_file_path),
uarch_file_path_(p->uarch_file_path),
csr_values_json_(p->csr_values),
extension_manager_(mavis::extension_manager::riscv::RISCVExtensionManager::fromISA(
supported_isa_string_, isa_file_path_ + std::string("/riscv_isa_spec.json"),
isa_file_path_)),
Expand Down Expand Up @@ -152,6 +153,32 @@ namespace atlas
}
}

void AtlasState::onBindTreeLate_()
{
// Write initial values to CSR registers
std::ifstream ifs(csr_values_json_);
nlohmann::json csr_values_json = nlohmann::json::parse(ifs);
for (const auto & [csr_name, hex_str] : csr_values_json.items())
{
sparta::Register* csr_reg = findRegister(csr_name);
if (csr_reg)
{
sparta_assert(csr_reg->getGroupNum()
== (sparta::RegisterBase::group_num_type)RegType::CSR,
"Provided initial value for not-CSR register: " << csr_name);
const uint64_t csr_val = std::stoull(std::string(hex_str), nullptr, 16);
std::cout << csr_name << ": " << HEX16(csr_val) << std::endl;
csr_reg->dmiWrite(csr_val);
}
else
{
std::cout
<< "WARNING: Provided initial value for CSR register that does not exist! "
<< csr_name << std::endl;
}
}
}

ActionGroup* AtlasState::preExecute_(AtlasState* state)
{
// TODO cnyce: Package up all rs1/rs2/rd registers, pc, opcode, etc.
Expand Down Expand Up @@ -466,7 +493,11 @@ namespace atlas
{
// Set PC
pc_ = next_pc_;
next_pc_ = 0;
DLOG("PC: 0x" << std::hex << pc_);

// Set Privilege Mode
priv_mode_ = next_priv_mode_;
DLOG("Privilege Mode: " << (uint32_t)priv_mode_);

// Increment instruction count
++sim_state_.inst_count;
Expand Down Expand Up @@ -524,12 +555,11 @@ namespace atlas

const AtlasInstPtr & insn = getCurrentInst();

int hart = getHartId();
std::string rs1_name = insn->hasRs1() ? insn->getRs1()->getName() : "";
uint64_t rs1_val = insn->hasRs1() ? insn->getRs1()->dmiRead<uint64_t>() : 0;
std::string rs2_name = insn->hasRs2() ? insn->getRs2()->getName() : "";
uint64_t rs2_val = insn->hasRs2() ? insn->getRs2()->dmiRead<uint64_t>() : 0;
std::string rd_name = insn->hasRd() ? insn->getRd()->getName() : "";
const std::string rs1_name = insn->hasRs1() ? insn->getRs1Reg()->getName() : "";
uint64_t rs1_val = insn->hasRs1() ? insn->getRs1Reg()->dmiRead<uint64_t>() : 0;
const std::string rs2_name = insn->hasRs2() ? insn->getRs2Reg()->getName() : "";
uint64_t rs2_val = insn->hasRs2() ? insn->getRs2Reg()->dmiRead<uint64_t>() : 0;
const std::string rd_name = insn->hasRd() ? insn->getRdReg()->getName() : "";

uint64_t rd_val_before = 0;
uint64_t rd_val_after = 0;
Expand All @@ -540,23 +570,23 @@ namespace atlas
Observer* obs = !observers_.empty() ? observers_.front().get() : nullptr;
sparta_assert(obs, "No observers enabled, nothing to debug!");

auto rd = insn->getRd();
auto rd_reg = insn->getRdReg();
rd_val_before = obs->getPrevRdValue();
rd_val_after = rd->dmiRead<uint64_t>();
rd_val_after = rd_reg->dmiRead<uint64_t>();

switch (rd->getGroupNum())
switch (rd_reg->getGroupNum())
{
case 0:
// INT
cosim_rd_val_after = cosim_query_->getIntRegValue(getHartId(), rd->getID());
cosim_rd_val_after = cosim_query_->getIntRegValue(getHartId(), rd_reg->getID());
break;
case 1:
// FP
cosim_rd_val_after = cosim_query_->getFpRegValue(getHartId(), rd->getID());
cosim_rd_val_after = cosim_query_->getFpRegValue(getHartId(), rd_reg->getID());
break;
case 2:
// VEC
cosim_rd_val_after = cosim_query_->getVecRegValue(getHartId(), rd->getID());
cosim_rd_val_after = cosim_query_->getVecRegValue(getHartId(), rd_reg->getID());
break;
case 3:
// Let this go to the default case assert. CSRs should not be written to in RD.
Expand Down Expand Up @@ -597,29 +627,42 @@ namespace atlas
int result_code = compareWithCoSimAndSync_();

std::unique_ptr<simdb::WorkerTask> task(
new InstSnapshotter(cosim_db_.get(), hart, rs1_name, rs1_val, rs2_name, rs2_val,
new InstSnapshotter(cosim_db_.get(), hart_id_, rs1_name, rs1_val, rs2_name, rs2_val,
rd_name, rd_val_before, rd_val_after, cosim_rd_val_after, has_imm,
imm, disasm, mnemonic, opcode, pc, priv, result_code));

cosim_db_->getTaskQueue()->addWorkerTask(std::move(task));

if (!all_csr_vals.empty())
{
task.reset(new CsrValuesSnapshotter(cosim_db_.get(), hart, pc, all_csr_vals));
task.reset(new CsrValuesSnapshotter(cosim_db_.get(), hart_id_, pc, all_csr_vals));
cosim_db_->getTaskQueue()->addWorkerTask(std::move(task));
}
}

uint64_t AtlasState::getMStatusInitialValue(const AtlasState* state, const uint64_t xlen_val)
void AtlasState::boot()
{
// TODO cnyce
(void)state;
(void)xlen_val;
return 42949672960;
}
std::cout << "Booting hartid " << std::dec << hart_id_ << std::endl;
{
AtlasState* state = this;

POKE_CSR_REG(MHARTID, hart_id_);

// TODO: Initialize MISA CSR with XLEN and enabled extensions
const uint64_t xlen_val = (xlen_ == 64) ? 2 : 1;
POKE_CSR_FIELD(MISA, mxl, xlen_val);

// Initialize MSTATUS/STATUS with User and Supervisor mode XLEN
POKE_CSR_FIELD(MSTATUS, uxl, xlen_val);
POKE_CSR_FIELD(MSTATUS, sxl, xlen_val);
POKE_CSR_FIELD(SSTATUS, uxl, xlen_val);

std::cout << state->getCsrRegister(MHARTID) << std::endl;
std::cout << state->getCsrRegister(MISA) << std::endl;
std::cout << state->getCsrRegister(MSTATUS) << std::endl;
std::cout << state->getCsrRegister(SSTATUS) << std::endl;
}

void AtlasState::postInit()
{
if (interactive_mode_)
{
auto observer = std::make_unique<SimController>();
Expand Down
Loading

0 comments on commit 2c5d2a9

Please sign in to comment.