Bump to MLX v0.15.2 by andresy · Pull Request #31 · ml-explore/mlx-c · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Bump to MLX v0.15.2 #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ set(MLX_BUILD_PYTHON_BINDINGS OFF)
FetchContent_Declare(
mlx
GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
GIT_TAG v0.14.0)
GIT_TAG v0.15.2)
FetchContent_MakeAvailable(mlx)

# ----------------------------- lib -----------------------------
Expand All @@ -39,6 +39,8 @@ set(mlxc-src
${CMAKE_CURRENT_LIST_DIR}/mlx/c/closure.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/compile.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/device.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/distributed.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/distributed_group.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/error.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/fast.cpp
${CMAKE_CURRENT_LIST_DIR}/mlx/c/fft.cpp
Expand Down
5 changes: 5 additions & 0 deletions docs/src/distributed_group.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Distributed Group
=================

.. doxygengroup:: mlx_distributed_group
   :content-only:
5 changes: 5 additions & 0 deletions docs/src/distributed_ops.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Distributed Operations
======================

.. doxygengroup:: distributed_ops
   :content-only:
2 changes: 2 additions & 0 deletions docs/src/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ bindings to MLX.
map
closure
future
distributed_group
ioutils

.. toctree::
Expand All @@ -51,6 +52,7 @@ bindings to MLX.
random
io
transforms
distributed_ops
compile
fast
metal
1 change: 1 addition & 0 deletions mlx/c/compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/compile.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down
30 changes: 30 additions & 0 deletions mlx/c/distributed.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */

// Own header first, so that the extern "C" definitions below are checked
// against the public prototypes declared in mlx/c/distributed.h.
#include "mlx/c/distributed.h"

#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
#include "mlx/c/private/stream.h"
#include "mlx/c/private/string.h"
#include "mlx/c/private/utils.h"

// All-gather `x` over `group`, forwarding to
// mlx::core::distributed::all_gather.
// A NULL `group` is forwarded as std::nullopt, which leaves the choice of
// group to MLX (presumably its default group -- confirm against MLX docs).
// The call is evaluated inside RETURN_MLX_C_ARRAY, which handles errors.
extern "C" mlx_array mlx_distributed_all_gather(
    mlx_array x,
    mlx_distributed_group group) {
  RETURN_MLX_C_ARRAY(mlx::core::distributed::all_gather(
      x->ctx, (group ? std::make_optional(group->ctx) : std::nullopt)));
}

// All-reduce (sum) `x` over `group`, forwarding to
// mlx::core::distributed::all_sum.
// A NULL `group` is forwarded as std::nullopt (see mlx_distributed_all_gather).
extern "C" mlx_array mlx_distributed_all_sum(
    mlx_array x,
    mlx_distributed_group group) {
  RETURN_MLX_C_ARRAY(mlx::core::distributed::all_sum(
      x->ctx, (group ? std::make_optional(group->ctx) : std::nullopt)));
}
36 changes: 36 additions & 0 deletions mlx/c/distributed.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */

/* The guard must be unique to this header: sharing a guard macro with another
 * header (e.g. ops.h) makes whichever is included second expand to nothing. */
#ifndef MLX_DISTRIBUTED_H
#define MLX_DISTRIBUTED_H

#include <stdio.h>

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
 * \defgroup distributed_ops Distributed collectives
 */
/**@{*/

/**
 * All-gather collective over a distributed group.
 *
 * Forwards to `mlx::core::distributed::all_gather`.
 *
 * @param x     The array contributed by the calling process.
 * @param group The group to communicate over. May be NULL; NULL is forwarded
 *              to MLX as "no group" (presumably MLX's default group --
 *              confirm).
 * @return A newly allocated array. On failure the library's error path
 *         applies (presumably NULL -- confirm).
 */
mlx_array mlx_distributed_all_gather(mlx_array x, mlx_distributed_group group);

/**
 * All-reduce (sum) collective over a distributed group.
 *
 * Forwards to `mlx::core::distributed::all_sum`.
 *
 * @param x     The array contributed by the calling process.
 * @param group The group to communicate over. May be NULL; NULL is forwarded
 *              to MLX as "no group" (presumably MLX's default group --
 *              confirm).
 * @return A newly allocated array. On failure the library's error path
 *         applies (presumably NULL -- confirm).
 */
mlx_array mlx_distributed_all_sum(mlx_array x, mlx_distributed_group group);

/**@}*/

#ifdef __cplusplus
}
#endif

#endif /* MLX_DISTRIBUTED_H */
34 changes: 34 additions & 0 deletions mlx/c/distributed_group.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/* Copyright © 2023-2024 Apple Inc. */

#include <cstring>

#include "mlx/c/distributed_group.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/stream.h"
#include "mlx/c/private/string.h"
#include "mlx/c/private/utils.h"

// Override of mlx_object_::tostring for distributed groups.
// Always yields the fixed label "mlx_distributed_group"; no group state is
// included in the text.
mlx_string_* mlx_distributed_group_::tostring() {
  static constexpr char kLabel[] = "mlx_distributed_group";
  RETURN_MLX_C_STRING(kLabel);
}

// Rank of this process within `group`, as reported by Group::rank().
// `group` must be a valid, non-NULL handle: it is dereferenced unchecked.
// The call is not wrapped in a RETURN_MLX_C_* macro, so nothing is caught here.
extern "C" int mlx_distributed_group_rank(mlx_distributed_group group) {
  auto& cpp_group = group->ctx;
  return cpp_group.rank();
}

// Number of processes in `group`, as reported by Group::size().
// `group` must be a valid, non-NULL handle: it is dereferenced unchecked.
// The call is not wrapped in a RETURN_MLX_C_* macro, so nothing is caught here.
extern "C" int mlx_distributed_group_size(mlx_distributed_group group) {
  mlx::core::distributed::Group& cpp_group = group->ctx;
  return cpp_group.size();
}

// Split `group` by (`color`, `key`) via Group::split and wrap the resulting
// sub-group in a newly allocated handle owned by the caller.
// Errors are handled by RETURN_MLX_C_DISTRIBUTED_GROUP, which evaluates the
// split inside its wrapper.
extern "C" mlx_distributed_group mlx_distributed_group_split(
    mlx_distributed_group group,
    int color,
    int key) {
  RETURN_MLX_C_DISTRIBUTED_GROUP(group->ctx.split(color, key));
}

extern "C" bool mlx_distributed_is_available() {
return mlx::core::distributed::is_available();
}

// Initialize MLX's distributed backend with the given `strict` flag and wrap
// the returned group in a newly allocated handle owned by the caller.
// The init call is evaluated inside RETURN_MLX_C_DISTRIBUTED_GROUP.
extern "C" mlx_distributed_group mlx_distributed_init(bool strict) {
  namespace dist = mlx::core::distributed;
  RETURN_MLX_C_DISTRIBUTED_GROUP(dist::init(strict));
}
54 changes: 54 additions & 0 deletions mlx/c/distributed_group.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright © 2023-2024 Apple Inc. */

#ifndef MLX_DISTRIBUTED_GROUP_H
#define MLX_DISTRIBUTED_GROUP_H

/* This header declares functions taking/returning `bool`; include it directly
 * so the header is self-contained when compiled as C. */
#include <stdbool.h>

#include "mlx/c/stream.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
 * \defgroup mlx_distributed_group MLX distributed
 */
/**@{*/

/**
 * An opaque handle to an MLX distributed group.
 *
 * Handles returned by this API are newly allocated library objects owned by
 * the caller (presumably released with mlx_free like other MLX objects --
 * confirm).
 */
typedef struct mlx_distributed_group_* mlx_distributed_group;

/**
 * Get the rank of the calling process within `group`.
 *
 * @param group A valid, non-NULL group handle.
 */
int mlx_distributed_group_rank(mlx_distributed_group group);

/**
 * Get the number of processes in `group`.
 *
 * @param group A valid, non-NULL group handle.
 */
int mlx_distributed_group_size(mlx_distributed_group group);

/**
 * Split `group` into sub-groups.
 *
 * @param group A valid, non-NULL group handle.
 * @param color Presumably, processes passing the same color end up in the same
 *              sub-group (MLX `Group::split` semantics -- confirm).
 * @param key   Presumably orders the ranks within the new sub-group (MLX
 *              `Group::split` semantics -- confirm).
 * @return A new group handle owned by the caller.
 */
mlx_distributed_group
mlx_distributed_group_split(mlx_distributed_group group, int color, int key);

/**
 * Check whether MLX reports distributed support as available.
 */
bool mlx_distributed_is_available(void);

/**
 * Initialize distributed communication and return the resulting group.
 *
 * @param strict Forwarded to `mlx::core::distributed::init`. Presumably it
 *               makes initialization fail rather than fall back when no
 *               backend is available (confirm).
 * @return A new group handle owned by the caller.
 */
mlx_distributed_group mlx_distributed_init(bool strict);

/**@}*/

#ifdef __cplusplus
}
#endif

#endif /* MLX_DISTRIBUTED_GROUP_H */
1 change: 1 addition & 0 deletions mlx/c/fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/fast.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down
6 changes: 6 additions & 0 deletions mlx/c/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down Expand Up @@ -1041,6 +1042,11 @@ mlx_var_all(mlx_array a, bool keepdims, int ddof, mlx_stream s) {
RETURN_MLX_C_ARRAY(mlx::core::var(a->ctx, keepdims, ddof, s->ctx));
}
// View `a` with element type `dtype` on stream `s`, forwarding to
// mlx::core::view. The C dtype is translated via MLX_CPP_ARRAY_DTYPE.
// The whole call stays inside RETURN_MLX_C_ARRAY, which handles errors.
extern "C" mlx_array mlx_view(
    mlx_array a,
    mlx_array_dtype dtype,
    mlx_stream s) {
  RETURN_MLX_C_ARRAY(mlx::core::view(
      a->ctx,
      MLX_CPP_ARRAY_DTYPE(dtype),
      s->ctx));
}
// Element-wise select between `x` and `y` by `condition` on stream `s`,
// forwarding to mlx::core::where. The call stays inside RETURN_MLX_C_ARRAY,
// which handles errors.
extern "C" mlx_array mlx_where(
    mlx_array condition,
    mlx_array x,
    mlx_array y,
    mlx_stream s) {
  RETURN_MLX_C_ARRAY(
      mlx::core::where(condition->ctx, x->ctx, y->ctx, s->ctx));
}
Expand Down
2 changes: 2 additions & 0 deletions mlx/c/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down Expand Up @@ -498,6 +499,7 @@ mlx_array mlx_var(
int ddof,
mlx_stream s);
mlx_array mlx_var_all(mlx_array a, bool keepdims, int ddof, mlx_stream s);
mlx_array mlx_view(mlx_array a, mlx_array_dtype dtype, mlx_stream s);
mlx_array
mlx_where(mlx_array condition, mlx_array x, mlx_array y, mlx_stream s);
mlx_array mlx_zeros(
Expand Down
17 changes: 17 additions & 0 deletions mlx/c/private/distributed_group.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/* Copyright © 2023-2024 Apple Inc. */

#ifndef MLX_DISTRIBUTED_GROUP_PRIVATE_H
#define MLX_DISTRIBUTED_GROUP_PRIVATE_H

#include <utility>

#include "mlx/c/distributed_group.h"
#include "mlx/c/private/object.h"
#include "mlx/mlx.h"

// Concrete object behind the opaque `mlx_distributed_group` C handle: wraps an
// mlx::core::distributed::Group so it can travel through the C API.
struct mlx_distributed_group_ : mlx_object_ {
  // `ctx` is a sink parameter taken by value, so move it into the member
  // instead of copying it a second time. Callers pass temporaries such as the
  // result of Group::split() or distributed::init().
  mlx_distributed_group_(mlx::core::distributed::Group ctx)
      : mlx_object_(), ctx(std::move(ctx)) {}

  // String description of this object (defined in distributed_group.cpp).
  mlx_string_* tostring() override;

  // The wrapped C++ group.
  mlx::core::distributed::Group ctx;
};

#endif
2 changes: 2 additions & 0 deletions mlx/c/private/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,5 +183,7 @@ static mlx_array_dtype mlx_c_dtypes[] = {
#define RETURN_MLX_C_STRING(str) RETURN_MLX_C_PTR(new mlx_string_(str))
#define RETURN_MLX_C_SAFETENSORS(st) RETURN_MLX_C_PTR(new mlx_safetensors_(st))
#define RETURN_MLX_C_FUTURE(f) RETURN_MLX_C_PTR(new mlx_future_(f))
// Wrap a C++ distributed Group expression in a newly allocated
// mlx_distributed_group_ and return it through RETURN_MLX_C_PTR, which
// evaluates the expression inside its own wrapper.
#define RETURN_MLX_C_DISTRIBUTED_GROUP(grp) \
  RETURN_MLX_C_PTR(new mlx_distributed_group_(grp))

#endif
1 change: 1 addition & 0 deletions mlx/c/random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/c/mlx.h"
#include "mlx/c/private/array.h"
#include "mlx/c/private/closure.h"
#include "mlx/c/private/distributed_group.h"
#include "mlx/c/private/future.h"
#include "mlx/c/private/io.h"
#include "mlx/c/private/map.h"
Expand Down
1 change: 1 addition & 0 deletions mlx/c/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/future.h"
#include "mlx/c/ioutils.h"
#include "mlx/c/map.h"
Expand Down
Loading
0