feat: constify mul* by DaniPopes · Pull Request #449 · recmo/uint · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat: constify mul* #449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

8000
Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Updated Pyo3. This is a **non-semver breaking change** to address a vulnerability reported on Pyo3. ([#460])
- Make `rotate*`, `*sh[lr]` functions `const` ([#441])
- Make `mul*` functions `const` ([#449])

[#441]: https://github.com/recmo/uint/pull/441
[#449]: https://github.com/recmo/uint/pull/449
[#460]: https://github.com/recmo/uint/pull/460

## [1.14.0] - 2025-03-25
Expand Down
55 changes: 36 additions & 19 deletions src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,10 @@ pub use self::{

trait DoubleWord<T>: Sized + Copy {
fn join(high: T, low: T) -> Self;
fn add(a: T, b: T) -> Self;
fn mul(a: T, b: T) -> Self;
fn muladd(a: T, b: T, c: T) -> Self;
fn muladd2(a: T, b: T, c: T, d: T) -> Self;

fn high(self) -> T;
fn low(self) -> T;
fn split(self) -> (T, T);
}

impl DoubleWord<u64> for u128 {
Expand All @@ -43,45 +39,66 @@ impl DoubleWord<u64> for u128 {
(Self::from(high) << 64) | Self::from(low)
}

/// Computes `a + b` as a 128-bit value.
#[inline(always)]
fn add(a: u64, b: u64) -> Self {
Self::from(a) + Self::from(b)
}

/// Computes `a * b` as a 128-bit value.
#[inline(always)]
fn mul(a: u64, b: u64) -> Self {
Self::from(a) * Self::from(b)
}

#[inline(always)]
fn high(self) -> u64 {
(self >> 64) as u64
}

#[inline(always)]
#[allow(clippy::cast_possible_truncation)]
fn low(self) -> u64 {
self as u64
}
}

#[derive(Clone, Copy)]
struct ConstDoubleWord<T>(T);

impl ConstDoubleWord<u128> {
#[inline(always)]
const fn ext(a: u64) -> u128 {
a as u128
}

/// Computes `a + b` as a 128-bit value.
#[inline(always)]
const fn add(a: u64, b: u64) -> Self {
Self(Self::ext(a) + Self::ext(b))
}

/// Computes `a * b + c` as a 128-bit value. Note that this cannot
/// overflow.
#[inline(always)]
fn muladd(a: u64, b: u64, c: u64) -> Self {
Self::from(a) * Self::from(b) + Self::from(c)
const fn muladd(a: u64, b: u64, c: u64) -> Self {
Self(Self::ext(a) * Self::ext(b) + Self::ext(c))
}

/// Computes `a * b + c + d` as a 128-bit value. Note that this cannot
/// overflow.
#[inline(always)]
fn muladd2(a: u64, b: u64, c: u64, d: u64) -> Self {
Self::from(a) * Self::from(b) + Self::from(c) + Self::from(d)
const fn muladd2(a: u64, b: u64, c: u64, d: u64) -> Self {
Self(Self::ext(a) * Self::ext(b) + Self::ext(c) + Self::ext(d))
}

#[inline(always)]
fn high(self) -> u64 {
(self >> 64) as u64
const fn high(self) -> u64 {
(self.0 >> 64) as u64
}

#[inline(always)]
#[allow(clippy::cast_possible_truncation)]
fn low(self) -> u64 {
self as u64
const fn low(self) -> u64 {
self.0 as u64
}

#[inline(always)]
fn split(self) -> (u64, u64) {
const fn split(self) -> (u64, u64) {
(self.low(), self.high())
}
}
Expand Down
64 changes: 38 additions & 26 deletions src/algorithms/mul.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(clippy::module_name_repetitions)]

use crate::algorithms::{ops::sbb, DoubleWord};
use crate::algorithms::{ops::sbb, ConstDoubleWord as DW};

/// ⚠️ Computes `result += a * b` and checks for overflow.
///
Expand All @@ -23,7 +23,7 @@ use crate::algorithms::{ops::sbb, DoubleWord};
/// assert_eq!(result, [12]);
/// ```
#[inline(always)]
pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
pub const fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
// Trim zeros from `a`
while let [0, rest @ ..] = a {
a = rest;
Expand Down Expand Up @@ -55,8 +55,10 @@ pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
let (a, b) = if b.len() > a.len() { (b, a) } else { (a, b) };

// Iterate over limbs of `b` and add partial products to `lhs`.
let mut i = 0;
let mut overflow = false;
for &b in b {
while i < b.len() {
let b = b[i];
if lhs.len() >= a.len() {
let (target, rest) = lhs.split_at_mut(a.len());
let carry = addmul_nx1(target, a, b);
Expand All @@ -67,24 +69,27 @@ pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
if lhs.is_empty() {
break;
}
addmul_nx1(lhs, &a[..lhs.len()], b);
addmul_nx1(lhs, a.split_at(lhs.len()).0, b);
}
lhs = &mut lhs[1..];
lhs = lhs.split_at_mut(1).1;
i += 1;
}
overflow
}

/// Computes `lhs += a` and returns the carry.
#[inline(always)]
pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
pub const fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
if a == 0 {
return 0;
}
for lhs in lhs {
(*lhs, a) = u128::add(*lhs, a).split();
let mut i = 0;
while i < lhs.len() {
(lhs[i], a) = DW::add(lhs[i], a).split();
if a == 0 {
return 0;
}
i += 1;
}
a
}
Expand All @@ -95,9 +100,9 @@ pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
///
/// Panics if the lengths are not the same.
#[inline(always)]
pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {
assert_eq!(lhs.len(), a.len());
assert_eq!(lhs.len(), b.len());
pub const fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {
assert!(lhs.len() == a.len());
assert!(lhs.len() == b.len());
match lhs.len() {
0 => {}
1 => addmul_1(lhs, a, b),
Expand All @@ -110,7 +115,7 @@ pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {

/// Computes `lhs += a * b` for 1 limb.
#[inline(always)]
fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) {
const fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) {
assume!(lhs.len() == 1);
assume!(a.len() == 1);
assume!(b.len() == 1);
Expand All @@ -120,7 +125,7 @@ fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) {

/// Computes `lhs += a * b` for 2 limbs.
#[inline(always)]
fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) {
const fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) {
assume!(lhs.len() == 2);
assume!(a.len() == 2);
assume!(b.len() == 2);
Expand All @@ -133,7 +138,7 @@ fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) {

/// Computes `lhs += a * b` for 3 limbs.
#[inline(always)]
fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) {
const fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) {
assume!(lhs.len() == 3);
assume!(a.len() == 3);
assume!(b.len() == 3);
Expand All @@ -150,7 +155,7 @@ fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) {

/// Computes `lhs += a * b` for 4 limbs.
#[inline(always)]
fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) {
const fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) {
assume!(lhs.len() == 4);
assume!(a.len() == 4);
assume!(b.len() == 4);
Expand All @@ -171,18 +176,20 @@ fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) {
}

#[inline(always)]
fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
let prod = u128::muladd2(a, b, c, *lhs);
const fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
let prod = DW::muladd2(a, b, c, *lhs);
*lhs = prod.low();
prod.high()
}

/// Computes `lhs *= a` and returns the carry.
#[inline(always)]
pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
pub const fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
let mut carry = 0;
for lhs in lhs {
(*lhs, carry) = u128::muladd(*lhs, a, carry).split();
let mut i = 0;
while i < lhs.len() {
(lhs[i], carry) = DW::muladd(lhs[i], a, carry).split();
i += 1;
}
carry
}
Expand All @@ -198,11 +205,13 @@ pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
/// }{2^{64⋅N}}} \end{aligned}
/// $$
#[inline(always)]
pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
pub const fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
assume!(lhs.len() == a.len());
let mut carry = 0;
for i in 0..a.len() {
(lhs[i], carry) = u128::muladd2(a[i], b, carry, lhs[i]).split();
let mut i = 0;
while i < a.len() {
(lhs[i], carry) = DW::muladd2(a[i], b, carry, lhs[i]).split();
i += 1;
}
carry
}
Expand All @@ -219,17 +228,20 @@ pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
/// $$
// OPT: `carry` and `borrow` can probably be merged into a single var.
#[inline(always)]
pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
pub const fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
assume!(lhs.len() == a.len());
let mut carry = 0;
let mut borrow = 0;
for i in 0..a.len() {
let mut i = 0;
while i < a.len() {
// Compute product limbs
let limb;
(limb, carry) = u128::muladd(a[i], b, carry).split();
(limb, carry) = DW::muladd(a[i], b, carry).split();

// Subtract
(lhs[i], borrow) = sbb(lhs[i], limb, borrow);

i += 1;
}
borrow + carry
}
Expand Down
18 changes: 9 additions & 9 deletions src/algorithms/ops.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use super::DoubleWord;
use super::ConstDoubleWord as DW;

#[inline(always)]
#[must_use]
pub fn adc(lhs: u64, rhs: u64, carry: u64) -> (u64, u64) {
let result = u128::from(lhs) + u128::from(rhs) + u128::from(carry);
result.split()
pub const fn adc(lhs: u64, rhs: u64, carry: u64) -> (u64, u64) {
let result = DW::ext(lhs) + DW::ext(rhs) + DW::ext(carry);
DW(result).split()
}

#[inline(always)]
#[must_use]
pub fn sbb(lhs: u64, rhs: u64, borrow: u64) -> (u64, u64) {
let result = u128::from(lhs)
.wrapping_sub(u128::from(rhs))
.wrapping_sub(u128::from(borrow));
(result.low(), result.high().wrapping_neg())
pub const fn sbb(lhs: u64, rhs: u64, borrow: u64) -> (u64, u64) {
let result = DW::ext(lhs)
.wrapping_sub(DW::ext(rhs))
.wrapping_sub(DW::ext(borrow));
(DW(result).low(), DW(result).high().wrapping_neg())
}
2 changes: 0 additions & 2 deletions src/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
/// Computes `self / rhs`, returning [`None`] if `rhs == 0`.
#[inline]
#[must_use]
#[allow(clippy::missing_const_for_fn)] // False positive
pub fn checked_div(self, rhs: Self) -> Option<Self> {
if rhs.is_zero() {
return None;
Expand All @@ -16,7 +15,6 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
/// Computes `self % rhs`, returning [`None`] if `rhs == 0`.
#[inline]
#[must_use]
#[allow(clippy::missing_const_for_fn)] // False positive
pub fn checked_rem(self, rhs: Self) -> Option<Self> {
if rhs.is_zero() {
return None;
Expand Down
6 changes: 2 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,17 +313,15 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
}

#[inline(always)]
fn apply_mask(&mut self) {
const fn apply_mask(&mut self) {
if Self::SHOULD_MASK {
self.limbs[LIMBS - 1] &= Self::MASK;
}
}

#[inline(always)]
const fn masked(mut self) -> Self {
if Self::SHOULD_MASK {
self.limbs[LIMBS - 1] &= Self::MASK;
}
self.apply_mask();
self
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ macro_rules! assume {
macro_rules! debug_unreachable {
($($t:tt)*) => {
if cfg!(debug_assertions) {
unreachable!($($t)*);
panic!($($t)*);
} else {
unsafe { core::hint::unreachable_unchecked() };
}
Expand Down
Loading
Loading
0