Improving build time by removing the gfx11xx and host code from rccl_float8.h by mberenjk · Pull Request #1789 · ROCm/rccl · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Improving build time by removing the gfx11xx and host code from rccl_float8.h #1789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 2 additions & 35 deletions src/include/rccl_float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,50 +40,17 @@ typedef struct
} rccl_bfloat8;

// __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__))
#elif HIP_VERSION >= 60300000
#elif HIP_VERSION >= 60300000 && !(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1030__))

#include <hip/hip_fp8.h>

#if __HIP_DEVICE_COMPILE__ && (defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) || (defined(__gfx1100__) || defined(__gfx1101__)))//HIP_FP8_TYPE_OCP is enabled.
typedef __hip_fp8_e4m3 rccl_float8;
typedef __hip_fp8_e5m2 rccl_bfloat8;
#elif __HIP_DEVICE_COMPILE__ && (defined(__gfx942__))
#if __HIP_DEVICE_COMPILE__ && (defined(__gfx942__))
typedef __hip_fp8_e4m3_fnuz rccl_float8;
typedef __hip_fp8_e5m2_fnuz rccl_bfloat8;
#else
typedef __hip_fp8_e4m3 rccl_float8;
typedef __hip_fp8_e5m2 rccl_bfloat8;
#endif

// Stream an rccl_float8 value as a human-readable number: the FP8 value is
// first widened to float, and the float is what gets written to the stream.
// Returns the same stream so insertions can be chained.
inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
{
  const float widened = static_cast<float>(f8);
  return os << widened;
}

// Stream an rccl_bfloat8 value as a human-readable number: the FP8 value is
// first widened to float, and the float is what gets written to the stream.
// Returns the same stream so insertions can be chained.
inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat8& bf8)
{
  const float widened = static_cast<float>(bf8);
  return os << widened;
}

// Multiply two rccl_float8 values. Both operands are widened to float and the
// product is computed and returned in float; the result is NOT rounded back
// to FP8. Callers that need an FP8 result must convert explicitly.
inline __host__ __device__ float operator*(rccl_float8 a, rccl_float8 b)
{
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs * rhs;
}

// Multiply two rccl_bfloat8 values. Both operands are widened to float and the
// product is computed and returned in float; the result is NOT rounded back
// to FP8. Callers that need an FP8 result must convert explicitly.
inline __host__ __device__ float operator*(rccl_bfloat8 a, rccl_bfloat8 b)
{
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs * rhs;
}

// Scale an rccl_float8 value by a float. The FP8 operand is widened to float
// and the product is returned in float (no rounding back to FP8).
// `b` is already a float, so it is used directly.
inline __host__ __device__ float operator*(rccl_float8 a, float b)
{
  const float lhs = static_cast<float>(a);
  return lhs * b;
}

// Scale an rccl_bfloat8 value by a float. The FP8 operand is widened to float
// and the product is returned in float (no rounding back to FP8).
// `b` is already a float, so it is used directly.
inline __host__ __device__ float operator*(rccl_bfloat8 a, float b)
{
  const float lhs = static_cast<float>(a);
  return lhs * b;
}

// For older versions of ROCm that do not include hip_fp8.h,
// we provide a local version of the header file as a fallback.
Expand Down
4 changes: 2 additions & 2 deletions test/common/CollectiveArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ namespace RcclUnitTesting
case ncclUint32: ss << scalarsPerRank.U4[this->globalRank]; break;
case ncclInt64: ss << scalarsPerRank.I8[this->globalRank]; break;
case ncclUint64: ss << scalarsPerRank.U8[this->globalRank]; break;
case ncclFloat8e4m3: ss << scalarsPerRank.F1[this->globalRank]; break;
case ncclFloat8e4m3: ss << (float)scalarsPerRank.F1[this->globalRank]; break;
case ncclFloat32: ss << scalarsPerRank.F4[this->globalRank]; break;
case ncclFloat64: ss << scalarsPerRank.F8[this->globalRank]; break;
case ncclFloat8e5m2: ss << scalarsPerRank.B1[this->globalRank]; break;
case ncclFloat8e5m2: ss << (float)scalarsPerRank.B1[this->globalRank]; break;
case ncclBfloat16: ss << scalarsPerRank.B2[this->globalRank]; break;
default: ss << "(UNKNOWN)";
}
Expand Down
4 changes: 2 additions & 2 deletions test/common/PtrUnion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,11 @@ namespace RcclUnitTesting
case ncclUint32: U4[idx] *= scalarsPerRank.U4[rank]; break;
case ncclInt64: I8[idx] *= scalarsPerRank.I8[rank]; break;
case ncclUint64: U8[idx] *= scalarsPerRank.U8[rank]; break;
case ncclFloat8e4m3: F1[idx] = rccl_float8(F1[idx] * scalarsPerRank.F1[rank]); break;
case ncclFloat8e4m3: F1[idx] = rccl_float8((float)F1[idx] * (float)scalarsPerRank.F1[rank]); break;
case ncclFloat16: F2[idx] = __float2half(__half2float(F2[idx]) * __half2float(scalarsPerRank.F2[rank])); break;
case ncclFloat32: F4[idx] *= scalarsPerRank.F4[rank]; break;
case ncclFloat64: F8[idx] *= scalarsPerRank.F8[rank]; break;
case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(B1[idx] * scalarsPerRank.B1[rank]); break;
case ncclFloat8e5m2: B1[idx] = rccl_bfloat8((float)B1[idx] * (float)scalarsPerRank.B1[rank]); break;
case ncclBfloat16: B2[idx] *= scalarsPerRank.B2[rank]; break;
default:
ERROR("Unsupported datatype\n");
Expand Down
0