[jit] Add caching to `LlvmTypeConverter::ConvertToLlvmType`. by copybara-service[bot] · Pull Request #2519 · google/xls · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[jit] Add caching to LlvmTypeConverter::ConvertToLlvmType. #2519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions xls/jit/function_base_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ enum class AllocationKind : uint8_t {
// functions.
class BufferAllocator {
public:
explicit BufferAllocator(const LlvmTypeConverter* type_converter)
explicit BufferAllocator(LlvmTypeConverter* type_converter)
: type_converter_(type_converter) {}

void SetAllocationKind(Node* node, AllocationKind kind) {
Expand Down Expand Up @@ -349,7 +349,7 @@ class BufferAllocator {
offset, node_size, current_offset_);
}

const LlvmTypeConverter* type_converter_;
LlvmTypeConverter* type_converter_;
absl::flat_hash_map<Node*, int64_t> temp_block_offsets_;
int64_t current_offset_ = 0;
int64_t alignment_ = 1;
Expand Down Expand Up @@ -965,8 +965,7 @@ absl::StatusOr<PartitionedFunction> BuildFunctionInternal(
// bytes when unpacking values.
absl::Status UnpackValue(llvm::Value* packed_buffer,
llvm::Value* unpacked_buffer, Type* xls_type,
int64_t bit_offset,
const LlvmTypeConverter& type_converter,
int64_t bit_offset, LlvmTypeConverter& type_converter,
llvm::IRBuilder<>* builder) {
switch (xls_type->kind()) {
case TypeKind::kBits: {
Expand Down Expand Up @@ -1056,7 +1055,7 @@ absl::Status UnpackValue(llvm::Value* packed_buffer,
// value.
absl::Status PackValue(llvm::Value* unpacked_buffer, llvm::Value* packed_buffer,
Type* xls_type, int64_t bit_offset,
const LlvmTypeConverter& type_converter,
LlvmTypeConverter& type_converter,
llvm::IRBuilder<>* builder) {
if (xls_type->GetFlatBitCount() == 0) {
return absl::OkStatus();
Expand Down
30 changes: 17 additions & 13 deletions xls/jit/llvm_type_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ int64_t LlvmTypeConverter::GetLlvmBitCount(int64_t xls_bit_count) const {
return int64_t{1} << CeilOfLog2(xls_bit_count);
}

llvm::Type* LlvmTypeConverter::ConvertToLlvmType(const Type* xls_type) const {
llvm::Type* LlvmTypeConverter::ConvertToLlvmType(const Type* xls_type) {
if (auto it = type_cache_.find(xls_type); it != type_cache_.end()) {
return it->second;
}
llvm::Type* llvm_type;
if (xls_type->IsBits()) {
llvm_type = llvm::IntegerType::get(
Expand Down Expand Up @@ -88,6 +91,7 @@ llvm::Type* LlvmTypeConverter::ConvertToLlvmType(const Type* xls_type) const {
LOG(FATAL) << absl::StrCat("Type not supported for LLVM conversion: %s",
xls_type->ToString());
}
type_cache_.emplace(xls_type, llvm_type);
return llvm_type;
}

Expand All @@ -105,7 +109,7 @@ int64_t LlvmTypeConverter::GetPackedTypeByteSize(const Type* type) const {
}

absl::StatusOr<llvm::Constant*> LlvmTypeConverter::ToLlvmConstant(
const Type* type, const Value& value) const {
const Type* type, const Value& value) {
return ToLlvmConstant(ConvertToLlvmType(type), value);
}

Expand Down Expand Up @@ -182,27 +186,26 @@ absl::StatusOr<llvm::Constant*> LlvmTypeConverter::ToIntegralConstant(
return llvm::ConstantInt::get(type, bits);
}

int64_t LlvmTypeConverter::GetTypeByteSize(const Type* type) const {
int64_t LlvmTypeConverter::GetTypeByteSize(const Type* type) {
return data_layout_.getTypeAllocSize(ConvertToLlvmType(type)).getFixedValue();
}

int64_t LlvmTypeConverter::GetTypeAbiAlignment(const Type* type) const {
int64_t LlvmTypeConverter::GetTypeAbiAlignment(const Type* type) {
return data_layout_.getABITypeAlign(ConvertToLlvmType(type)).value();
}
int64_t LlvmTypeConverter::GetTypePreferredAlignment(const Type* type) const {
int64_t LlvmTypeConverter::GetTypePreferredAlignment(const Type* type) {
return data_layout_.getPrefTypeAlign(ConvertToLlvmType(type)).value();
}

TypeBufferMetadata LlvmTypeConverter::GetTypeBufferMetadata(
const Type* type) const {
TypeBufferMetadata LlvmTypeConverter::GetTypeBufferMetadata(const Type* type) {
return TypeBufferMetadata{
.size = GetTypeByteSize(type),
.preferred_alignment = GetTypePreferredAlignment(type),
.abi_alignment = GetTypeAbiAlignment(type),
.packed_size = GetPackedTypeByteSize(type)};
}

int64_t LlvmTypeConverter::AlignFor(const Type* type, int64_t offset) const {
int64_t LlvmTypeConverter::AlignFor(const Type* type, int64_t offset) {
llvm::Align alignment =
data_layout_.getPrefTypeAlign(ConvertToLlvmType(type));
return llvm::alignTo(offset, alignment);
Expand All @@ -220,7 +223,7 @@ llvm::Value* LlvmTypeConverter::GetToken() const {

llvm::Value* LlvmTypeConverter::AsSignedValue(
llvm::Value* value, Type* xls_type, llvm::IRBuilder<>& builder,
std::optional<llvm::Type*> dest_type) const {
std::optional<llvm::Type*> dest_type) {
CHECK(xls_type->IsBits());
int64_t xls_bit_count = xls_type->AsBitsOrDie()->bit_count();
int64_t llvm_bit_count = GetLlvmBitCount(xls_bit_count);
Expand All @@ -246,7 +249,7 @@ llvm::Value* LlvmTypeConverter::AsSignedValue(
}

llvm::Value* LlvmTypeConverter::PaddingMask(Type* xls_type,
llvm::IRBuilder<>& builder) const {
llvm::IRBuilder<>& builder) {
CHECK(xls_type->IsBits());
int64_t xls_bit_count = xls_type->AsBitsOrDie()->bit_count();
int64_t llvm_bit_count = GetLlvmBitCount(xls_type->AsBitsOrDie());
Expand All @@ -261,12 +264,13 @@ llvm::Value* LlvmTypeConverter::PaddingMask(Type* xls_type,
}

llvm::Value* LlvmTypeConverter::InvertedPaddingMask(
Type* xls_type, llvm::IRBuilder<>& builder) const {
Type* xls_type, llvm::IRBuilder<>& builder) {
return builder.CreateNot(PaddingMask(xls_type, builder));
}

llvm::Value* LlvmTypeConverter::ClearPaddingBits(
llvm::Value* value, Type* xls_type, llvm::IRBuilder<>& builder) const {
llvm::Value* LlvmTypeConverter::ClearPaddingBits(llvm::Value* value,
Type* xls_type,
llvm::IRBuilder<>& builder) {
if (!xls_type->IsBits()) {
// TODO(meheff): Handle non-bits types.
return value;
Expand Down
26 changes: 13 additions & 13 deletions xls/jit/llvm_type_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class LlvmTypeConverter {
LlvmTypeConverter(llvm::LLVMContext* context,
const llvm::DataLayout& data_layout);

llvm::Type* ConvertToLlvmType(const Type* type) const;
llvm::Type* ConvertToPointerToLlvmType(const Type* type) const {
llvm::Type* ConvertToLlvmType(const Type* type);
llvm::Type* ConvertToPointerToLlvmType(const Type* type) {
return llvm::PointerType::get(ConvertToLlvmType(type), 0);
}

Expand All @@ -61,7 +61,7 @@ class LlvmTypeConverter {
absl::StatusOr<llvm::Constant*> ToLlvmConstant(llvm::Type* type,
const Value& value) const;
absl::StatusOr<llvm::Constant*> ToLlvmConstant(const Type* type,
const Value& value) const;
const Value& value);

// Returns a constant zero of the given type.
static llvm::Constant* ZeroOfType(llvm::Type* type);
Expand All @@ -72,22 +72,22 @@ class LlvmTypeConverter {
// instead of the three that the flat bit count would suggest. The type width
// rules aren't necessarily immediately obvious, but fortunately the
// DataLayout object can handle ~all of the work for us.
int64_t GetTypeByteSize(const Type* type) const;
int64_t GetTypeByteSize(const Type* type);

// Returns the preferred alignment for the given type.
int64_t GetTypePreferredAlignment(const Type* type) const;
int64_t GetTypePreferredAlignment(const Type* type);

// Returns the alignment requirement for the given type.
int64_t GetTypeAbiAlignment(const Type* type) const;
int64_t GetTypeAbiAlignment(const Type* type);

TypeBufferMetadata GetTypeBufferMetadata(const Type* type) const;
TypeBufferMetadata GetTypeBufferMetadata(const Type* type);

// Returns the next position (starting from offset) where LLVM would consider
// an object of the given type to have ended; specifically, the next position
// that matches the greater of the stack alignment and the type's preferred
// alignment. As above, the rules aren't immediately obvious, but the
// DataLayout object takes care of the details.
int64_t AlignFor(const Type* type, int64_t offset) const;
int64_t AlignFor(const Type* type, int64_t offset);

// Returns a new Value representing the LLVM form of a Token.
llvm::Value* GetToken() const;
Expand Down Expand Up @@ -117,19 +117,18 @@ class LlvmTypeConverter {
// Zeros the padding bits of the given LLVM value representing an XLS value of
// the given XLS type. Bits-typed XLS values are padded out to powers of two.
llvm::Value* ClearPaddingBits(llvm::Value* value, Type* xls_type,
llvm::IRBuilder<>& builder) const;
llvm::IRBuilder<>& builder);

// Returns a mask which is 0 in padded bit positions and 1 in non-padding bit
// positions for the LLVM representation of the given XLS type. For example,
// if the XLS type is bits[3] represented using an i8 in LLVM, PaddingMask
// would return:
//
// i8:0b0000_0111
llvm::Value* PaddingMask(Type* xls_type, llvm::IRBuilder<>& builder) const;
llvm::Value* PaddingMask(Type* xls_type, llvm::IRBuilder<>& builder);

// Returns the bitwise NOT of padding mask, e.g., 0b1111_1000.
llvm::Value* InvertedPaddingMask(Type* xls_type,
llvm::IRBuilder<>& builder) const;
llvm::Value* InvertedPaddingMask(Type* xls_type, llvm::IRBuilder<>& builder);

// Converts the given LLVM value representing and XLS value of the given type
// into a signed representation. This involves extending the sign-bit of the
Expand All @@ -145,7 +144,7 @@ class LlvmTypeConverter {
// returning.
llvm::Value* AsSignedValue(
llvm::Value* value, Type* xls_type, llvm::IRBuilder<>& builder,
std::optional<llvm::Type*> dest_type = std::nullopt) const;
std::optional<llvm::Type*> dest_type = std::nullopt);

// Creates a TypeLayout object describing the native layout of given xls type.
TypeLayout CreateTypeLayout(Type* xls_type);
Expand All @@ -165,6 +164,7 @@ class LlvmTypeConverter {

llvm::LLVMContext& context_;
llvm::DataLayout data_layout_;
TypeCache type_cache_;
};

} // namespace xls
Expand Down
291F
0