[OM] Use custom CAPI wrappers for Object. by mikeurbach · Pull Request #5275 · llvm/circt · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[OM] Use custom CAPI wrappers for Object. #5275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/circt/Dialect/OM/Evaluator/Evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ struct Evaluator {
};

/// A composite Object, which has a type and fields.
struct Object {
/// Enables the shared_from_this functionality so Object pointers can be passed
/// through the CAPI and unwrapped back into C++ smart pointers with the
/// appropriate reference count.
struct Object : std::enable_shared_from_this<Object> {
/// Get the type of the Object.
mlir::Type getType();

Expand Down
28 changes: 21 additions & 7 deletions lib/CAPI/Dialect/OM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,22 @@ bool omTypeIsAClassType(MlirType type) { return unwrap(type).isa<ClassType>(); }
//===----------------------------------------------------------------------===//

DEFINE_C_API_PTR_METHODS(OMEvaluator, circt::om::Evaluator)
DEFINE_C_API_PTR_METHODS(OMObject, std::shared_ptr<circt::om::Object>)

/// Define our own wrap and unwrap instead of using the usual macro. This is to
/// handle the std::shared_ptr reference counts appropriately. We want to always
/// create *new* shared pointers to the Object when we wrap it for C, to
/// increment the reference count. We want to use the shared_from_this
/// functionality to ensure it is unwrapped into C++ with the correct reference
/// count.

/// Wrap a shared Object pointer into an opaque C handle.
///
/// A fresh std::shared_ptr is heap-allocated and takes over `object`; the raw
/// Object pointer it manages is what gets stored in the returned handle. The
/// heap-allocated shared_ptr is not deleted by this function, so the reference
/// it holds stays alive after the call returns.
/// NOTE(review): no matching release is visible in this snippet -- confirm
/// that holding this extra reference indefinitely is intended.
static inline OMObject wrap(std::shared_ptr<Object> object) {
  auto *retained = new std::shared_ptr<Object>(std::move(object));
  void *rawObject = static_cast<void *>(retained->get());
  return OMObject{rawObject};
}

/// Unwrap an opaque C handle back into a shared Object pointer.
///
/// The handle's pointer is reinterpreted as an Object, and shared_from_this is
/// used to obtain a std::shared_ptr that shares ownership with the shared_ptrs
/// already managing that Object.
///
/// Precondition: `c.ptr` must be non-null and must refer to an Object that is
/// currently owned by a std::shared_ptr; no check is performed here.
static inline std::shared_ptr<Object> unwrap(OMObject c) {
  auto *rawObject = static_cast<Object *>(c.ptr);
  return rawObject->shared_from_this();
}

//===----------------------------------------------------------------------===//
// Evaluator API.
Expand Down Expand Up @@ -74,9 +89,8 @@ OMObject omEvaluatorInstantiate(OMEvaluator evaluator, MlirAttribute className,
if (failed(result))
return OMObject();

// Wrap and return a *new* shared pointer to the Object, to ensure the
// reference count is kept up to date.
return wrap(new std::shared_ptr<Object>(result.value()));
// Wrap and return the Object.
return wrap(result.value());
}

/// Get the Module the Evaluator is built from.
Expand All @@ -97,15 +111,15 @@ bool omEvaluatorObjectIsNull(OMObject object) {

/// Get the Type from an Object, which will be a ClassType.
MlirType omEvaluatorObjectGetType(OMObject object) {
return wrap((*unwrap(object))->getType());
return wrap(unwrap(object)->getType());
}

/// Get a field from an Object, which must contain a field of that name.
OMObjectValue omEvaluatorObjectGetField(OMObject object, MlirAttribute name) {
// Unwrap the Object and get the field of the name, which the client must
// supply as a StringAttr.
FailureOr<ObjectValue> result =
(*unwrap(object))->getField(unwrap(name).cast<StringAttr>());
unwrap(object)->getField(unwrap(name).cast<StringAttr>());

// If getField failed, return a null ObjectValue. A Diagnostic will be emitted
// in this case.
Expand All @@ -114,7 +128,7 @@ OMObjectValue omEvaluatorObjectGetField(OMObject object, MlirAttribute name) {

// If the field is an Object, return an ObjectValue with the Object set.
if (auto *object = std::get_if<std::shared_ptr<Object>>(&result.value()))
return OMObjectValue{MlirAttribute(), wrap(object)};
return OMObjectValue{MlirAttribute(), wrap(*object)};

// If the field is an Attribute, return an ObjectValue with the Primitive set.
if (auto *primitive = std::get_if<Attribute>(&result.value()))
Expand Down
41 changes: 36 additions & 5 deletions test/CAPI/om.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,18 @@
#include <stdio.h>

void testEvaluator(MlirContext ctx) {
const char *testIR = "module {"
" om.class @Test(%param: i8) {"
" om.class.field @field, %param : i8"
" }"
"}";
const char *testIR =
"module {"
" om.class @Test(%param: i8) {"
" om.class.field @field, %param : i8"
" %0 = om.object @Child() : () -> !om.class.type<@Child>"
" om.class.field @child, %0 : !om.class.type<@Child>"
" }"
" om.class @Child() {"
" %0 = om.constant 14 : i64"
" om.class.field @foo, %0 : i64"
" }"
"}";

// Set up the Evaluator.
MlirModule testModule =
Expand Down Expand Up @@ -92,6 +99,30 @@ void testEvaluator(MlirContext ctx) {

// CHECK: 42 : i8
mlirAttributeDump(fieldValue);

// Test get field success for child object.

MlirAttribute childFieldName =
mlirStringAttrGet(ctx, mlirStringRefCreateFromCString("child"));

OMObjectValue childField = omEvaluatorObjectGetField(object, childFieldName);

OMObject child = omEvaluatorObjectValueGetObject(childField);

// CHECK: 0
fprintf(stderr, "child object is null: %d\n", omEvaluatorObjectIsNull(child));

OMObjectValue foo = omEvaluatorObjectGetField(
child, mlirStringAttrGet(ctx, mlirStringRefCreateFromCString("foo")));

// CHECK: child object field is primitive: 1
fprintf(stderr, "child object field is primitive: %d\n",
omEvaluatorObjectValueIsAPrimitive(foo));

MlirAttribute fooValue = omEvaluatorObjectValueGetPrimitive(foo);

// CHECK: 14 : i64
mlirAttributeDump(fooValue);
}

int main() {
Expand Down
0