refactor(shares): extract shares math to lib by Rubilmax · Pull Request #94 · morpho-org/morpho-blue · GitHub

refactor(shares): extract shares math to lib #94


Merged
merged 9 commits into from
Jul 11, 2023
51 changes: 21 additions & 30 deletions src/Blue.sol
@@ -5,7 +5,8 @@ import {IIrm} from "src/interfaces/IIrm.sol";
import {IERC20} from "src/interfaces/IERC20.sol";
import {IOracle} from "src/interfaces/IOracle.sol";

-import {WadRayMath} from "morpho-utils/math/WadRayMath.sol";
+import {SharesMath} from "./libraries/SharesMath.sol";
+import {FixedPointMathLib} from "solmate/utils/FixedPointMathLib.sol";
import {SafeTransferLib} from "src/libraries/SafeTransferLib.sol";

uint256 constant WAD = 1e18;
@@ -31,7 +32,8 @@ function toId(Market calldata market) pure returns (Id) {
}

contract Blue {
-using WadRayMath for uint256;
+using SharesMath for uint256;
+using FixedPointMathLib for uint256;
using SafeTransferLib for IERC20;

// Storage.
@@ -107,14 +109,9 @@ contract Blue {

accrueInterests(market, id);

-if (totalSupply[id] == 0) {
-supplyShare[id][msg.sender] = WAD;
-totalSupplyShares[id] = WAD;
-} else {
-uint256 shares = amount.wadDivDown(totalSupply[id]).wadMulDown(totalSupplyShares[id]);
-supplyShare[id][msg.sender] += shares;
-totalSupplyShares[id] += shares;
-}
+uint256 shares = amount.toSharesDown(totalSupply[id], totalSupplyShares[id]);
+supplyShare[id][msg.sender] += shares;
+totalSupplyShares[id] += shares;

totalSupply[id] += amount;
Comment on lines +112 to 116
Contributor

I feel like there is a vulnerability in this implementation.
Virtual shares are taken into account for the shares calculations (in the SharesMath lib), but are never added to totalSupplyShares.
So it seems that the first supplier still owns all the market shares, and can still perform the inflation attack.
Shouldn’t we add the virtual shares to totalSupplyShares at the initialization of the market (or when totalSupplyShares == 0)?
Maybe I’m wrong.

Contributor

My understanding from https://docs.openzeppelin.com/contracts/4.x/erc4626#inflation-attack is that the virtual shares are meant to capture part of the attacker’s donation (which aims to inflate the rate) and therefore make the attack unprofitable.

Contributor

And the same for totalSupply: if totalSupply[id] == 0, we should add 1 to totalSupply[id].

Contributor
QGarchery, Jul 11, 2023

In this implementation, when you do any computation with the shares, you take into account the virtual shares and the virtual assets (only 1 virtual asset actually). You don't add those virtual shares and assets to the mappings representing the actual values.

I think that this implementation works, because those virtual shares actually catch part of the attacker's donation. For example, writing V for the virtual shares (V = $10^\delta$):

1. The attacker supplies a0 assets.
   The added shares are toSharesDown(a0, 0, 0) = a0 * V.

   assets    shares
   a0        a0 * V

2. The attacker transfers a1 assets.

   assets    shares
   a0 + a1   a0 * V

3. A user supplies u assets.
   The added shares are toSharesDown(u, a0 + a1, a0 * V) = u * (V + a0 * V) / (1 + a0 + a1) = V * u * (1 + a0) / (1 + a0 + a1), which is the number of shares the user gets according to the OpenZeppelin page.

Keep in mind that this manipulation cannot be done on Morpho Blue, because it does not rely on token balances (so there is no way to modify the total assets simply by transferring tokens to the contract). Instead, the purpose of this PR is to ensure a high initial conversion rate.
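The walkthrough above can be checked numerically. The Python sketch below is a model, not the contract: it assumes V = 10^18 (SharesMath's VIRTUAL_SHARES, i.e. δ = 18) and one virtual asset, uses Python integers in place of uint256 (no overflow reverts), and treats the donation in step 2 as hypothetical, since Blue does not read token balances. The function names are illustrative.

```python
VIRTUAL_SHARES = 10**18  # V, matching SharesMath.VIRTUAL_SHARES (delta = 18)
VIRTUAL_ASSETS = 1  # matching SharesMath.VIRTUAL_ASSETS


def to_shares_down(assets: int, total_assets: int, total_shares: int) -> int:
    """Model of SharesMath.toSharesDown: floor(assets * (TS + V) / (TA + 1))."""
    return assets * (total_shares + VIRTUAL_SHARES) // (total_assets + VIRTUAL_ASSETS)


def user_shares_after_donation(a0: int, a1: int, u: int) -> int:
    """Replays steps 1-3 of the walkthrough and returns the user's new shares."""
    # Step 1: the attacker supplies a0 assets to an empty market.
    total_shares = to_shares_down(a0, 0, 0)  # = a0 * V
    total_assets = a0
    # Step 2 (hypothetical on Blue): a1 assets are donated, so only total assets grow.
    total_assets += a1
    # Step 3: the user supplies u assets at the inflated rate.
    return to_shares_down(u, total_assets, total_shares)


a0, a1, u = 1, 10**6, 10**3
shares = user_shares_after_donation(a0, a1, u)
# Matches the closed form V * u * (1 + a0) / (1 + a0 + a1), rounded down.
assert shares == VIRTUAL_SHARES * u * (1 + a0) // (1 + a0 + a1)
# Despite a donation 1000x larger than the deposit, the user still receives shares.
assert shares > 0
```

Because the virtual shares are counted in every conversion, the donation is partly credited to shares nobody owns, so the user is not rounded down to zero shares.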

Contributor

Discussed with @QGarchery; I agree it's OK.


@@ -128,7 +125,7 @@ contract Blue {

accrueInterests(market, id);

-uint256 shares = amount.wadDivUp(totalSupply[id]).wadMulUp(totalSupplyShares[id]);
+uint256 shares = amount.toSharesUp(totalSupply[id], totalSupplyShares[id]);
supplyShare[id][msg.sender] -= shares;
totalSupplyShares[id] -= shares;

@@ -148,14 +145,9 @@ contract Blue {

accrueInterests(market, id);

-if (totalBorrow[id] == 0) {
-borrowShare[id][msg.sender] = WAD;
-totalBorrowShares[id] = WAD;
-} else {
-uint256 shares = amount.wadDivUp(totalBorrow[id]).wadMulUp(totalBorrowShares[id]);
-borrowShare[id][msg.sender] += shares;
-totalBorrowShares[id] += shares;
-}
+uint256 shares = amount.toSharesUp(totalBorrow[id], totalBorrowShares[id]);
+borrowShare[id][msg.sender] += shares;
+totalBorrowShares[id] += shares;

totalBorrow[id] += amount;

@@ -172,7 +164,7 @@ contract Blue {

accrueInterests(market, id);

-uint256 shares = amount.wadDivDown(totalBorrow[id]).wadMulDown(totalBorrowShares[id]); // TODO: totalBorrow[id] > 0 because ???
+uint256 shares = amount.toSharesDown(totalBorrow[id], totalBorrowShares[id]);
borrowShare[id][msg.sender] -= shares;
totalBorrowShares[id] -= shares;

@@ -222,11 +214,11 @@ contract Blue {
require(!isHealthy(market, id, borrower), "cannot liquidate a healthy position");

// The liquidation incentive is 1 + ALPHA * (1 / LLTV - 1).
-uint256 incentive = WAD + ALPHA.wadMulDown(WAD.wadDivDown(market.lltv) - WAD);
-uint256 repaid = seized.wadMulUp(market.collateralOracle.price()).wadDivUp(incentive).wadDivUp(
+uint256 incentive = WAD + ALPHA.mulWadDown(WAD.divWadDown(market.lltv) - WAD);
+uint256 repaid = seized.mulWadUp(market.collateralOracle.price()).divWadUp(incentive).divWadUp(
market.borrowableOracle.price()
);
-uint256 repaidShares = repaid.wadDivDown(totalBorrow[id]).wadMulDown(totalBorrowShares[id]);
+uint256 repaidShares = repaid.toSharesDown(totalBorrow[id], totalBorrowShares[id]);

borrowShare[id][borrower] -= repaidShares;
totalBorrowShares[id] -= repaidShares;
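The liquidation arithmetic in this hunk can be modeled in Python to make the rounding directions visible: the incentive rounds down and the repaid amount rounds up, so both err against the liquidator. This is a sketch under assumptions: the helpers reimplement solmate's WAD operations with Python integers (no overflow reverts), and the ALPHA and LLTV values are hypothetical, since ALPHA's actual value is not shown in this diff.

```python
WAD = 10**18


def mul_wad_down(x: int, y: int) -> int:
    return x * y // WAD


def mul_wad_up(x: int, y: int) -> int:
    return (x * y + WAD - 1) // WAD


def div_wad_down(x: int, y: int) -> int:
    return x * WAD // y


def div_wad_up(x: int, y: int) -> int:
    return (x * WAD + y - 1) // y


def liquidation_incentive(alpha: int, lltv: int) -> int:
    # 1 + ALPHA * (1 / LLTV - 1), all in WAD, rounded down.
    return WAD + mul_wad_down(alpha, div_wad_down(WAD, lltv) - WAD)


def repaid_for_seized(seized: int, collateral_price: int, borrowable_price: int, incentive: int) -> int:
    # Seized collateral value, divided by the incentive, quoted in borrowable units (rounded up).
    return div_wad_up(div_wad_up(mul_wad_up(seized, collateral_price), incentive), borrowable_price)


ALPHA = WAD // 2  # hypothetical value, for illustration only
LLTV = 8 * WAD // 10  # hypothetical 80% LLTV
incentive = liquidation_incentive(ALPHA, LLTV)  # 1 + 0.5 * (1.25 - 1) = 1.125
repaid = repaid_for_seized(900, WAD, WAD, incentive)  # 900 / 1.125 = 800
```

With a seizure of a single unit, rounding down would give 0 repaid; rounding up yields 1, so a liquidator cannot seize dust for free.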
@@ -236,7 +228,7 @@ contract Blue {

// Realize the bad debt if needed.
if (collateral[id][borrower] == 0) {
-totalSupply[id] -= borrowShare[id][borrower].wadDivUp(totalBorrowShares[id]).wadMulUp(totalBorrow[id]);
+totalSupply[id] -= borrowShare[id][borrower].toAssetsUp(totalBorrow[id], totalBorrowShares[id]);
totalBorrowShares[id] -= borrowShare[id][borrower];
borrowShare[id][borrower] = 0;
}
@@ -252,7 +244,7 @@ contract Blue {

if (marketTotalBorrow != 0) {
uint256 borrowRate = market.irm.borrowRate(market);
-uint256 accruedInterests = marketTotalBorrow.wadMulDown(borrowRate * (block.timestamp - lastUpdate[id]));
+uint256 accruedInterests = marketTotalBorrow.mulWadDown(borrowRate * (block.timestamp - lastUpdate[id]));
totalBorrow[id] = marketTotalBorrow + accruedInterests;
totalSupply[id] += accruedInterests;
}
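The accrual step above charges simple (non-compounded) interest over the elapsed window and credits the same amount to suppliers. A minimal Python model, assuming borrowRate is a per-second rate in WAD, as the multiplication by the elapsed seconds suggests; the rate and amounts in the example are hypothetical.

```python
WAD = 10**18


def mul_wad_down(x: int, y: int) -> int:
    return x * y // WAD


def accrue_interests(total_borrow: int, total_supply: int, borrow_rate: int, elapsed: int) -> tuple[int, int]:
    """Model of accrueInterests: returns the updated (total_borrow, total_supply)."""
    if total_borrow == 0:
        # Mirrors `if (marketTotalBorrow != 0)`: nothing accrues without debt.
        return total_borrow, total_supply
    # Linear interest over the window: totalBorrow * rate * elapsed, rounded down.
    accrued = mul_wad_down(total_borrow, borrow_rate * elapsed)
    return total_borrow + accrued, total_supply + accrued


# Hypothetical: 1% per second for 10 seconds on 1,000,000 units borrowed (10% in total).
new_borrow, new_supply = accrue_interests(10**6, 2 * 10**6, WAD // 100, 10)
```

Since interest is added to both totals, the supply-side share price rises by exactly the amount borrowers owe.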
@@ -267,10 +259,9 @@ contract Blue {
if (borrowShares == 0) return true;

// totalBorrowShares[id] > 0 when borrowShares > 0.
-uint256 borrowValue = borrowShares.wadDivUp(totalBorrowShares[id]).wadMulUp(totalBorrow[id]).wadMulUp(
-market.borrowableOracle.price()
-);
-uint256 collateralValue = collateral[id][user].wadMulDown(market.collateralOracle.price());
-return collateralValue.wadMulDown(market.lltv) >= borrowValue;
+uint256 borrowValue =
+borrowShares.toAssetsUp(totalBorrow[id], totalBorrowShares[id]).mulWadUp(market.borrowableOracle.price());
+uint256 collateralValue = collateral[id][user].mulWadDown(market.collateralOracle.price());
+return collateralValue.mulWadDown(market.lltv) >= borrowValue;
}
}
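The updated health check converts debt shares to assets with toAssetsUp and values collateral with mulWadDown. The Python sketch below models it with integers standing in for uint256; the scenario (a single borrower in an otherwise empty market, unit prices, an 80% LLTV) is hypothetical.

```python
WAD = 10**18
VIRTUAL_SHARES = 10**18
VIRTUAL_ASSETS = 1


def mul_wad_down(x: int, y: int) -> int:
    return x * y // WAD


def mul_wad_up(x: int, y: int) -> int:
    return (x * y + WAD - 1) // WAD


def to_assets_up(shares: int, total_assets: int, total_shares: int) -> int:
    # Model of SharesMath.toAssetsUp (ceiling division).
    num = shares * (total_assets + VIRTUAL_ASSETS)
    den = total_shares + VIRTUAL_SHARES
    return (num + den - 1) // den


def is_healthy(
    borrow_shares: int,
    total_borrow: int,
    total_borrow_shares: int,
    collateral: int,
    collateral_price: int,
    borrowable_price: int,
    lltv: int,
) -> bool:
    if borrow_shares == 0:
        return True
    # Debt rounds up and collateral rounds down, so rounding errs toward "unhealthy".
    borrow_value = mul_wad_up(to_assets_up(borrow_shares, total_borrow, total_borrow_shares), borrowable_price)
    collateral_value = mul_wad_down(collateral, collateral_price)
    return mul_wad_down(collateral_value, lltv) >= borrow_value


V = VIRTUAL_SHARES
LLTV = 8 * WAD // 10  # hypothetical 80% LLTV
# The only borrower owes 800 units and holds 800 * V debt shares.
healthy_at_1000 = is_healthy(800 * V, 800, 800 * V, 1000, WAD, WAD, LLTV)
healthy_at_999 = is_healthy(800 * V, 800, 800 * V, 999, WAD, WAD, LLTV)
```

At exactly 1000 units of collateral the position sits on the LLTV boundary and passes; one unit less fails.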
37 changes: 37 additions & 0 deletions src/libraries/SharesMath.sol
@@ -0,0 +1,37 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import {FixedPointMathLib} from "solmate/utils/FixedPointMathLib.sol";

/// @notice Shares management library.
/// @dev This implementation mitigates share price manipulations, using OpenZeppelin's virtual shares: https://docs.openzeppelin.com/contracts/4.x/erc4626#inflation-attack.
library SharesMath {
using FixedPointMathLib for uint256;

uint256 internal constant VIRTUAL_SHARES = 1e18;
uint256 internal constant VIRTUAL_ASSETS = 1;

/// @dev Calculates the value of the given assets quoted in shares, rounding down.
/// Note: provided that assets <= totalAssets, this function satisfies the invariant: shares <= totalShares.
function toSharesDown(uint256 assets, uint256 totalAssets, uint256 totalShares) internal pure returns (uint256) {
return assets.mulDivDown(totalShares + VIRTUAL_SHARES, totalAssets + VIRTUAL_ASSETS);
}

/// @dev Calculates the value of the given shares quoted in assets, rounding down.
/// Note: provided that shares <= totalShares, this function satisfies the invariant: assets <= totalAssets.
function toAssetsDown(uint256 shares, uint256 totalAssets, uint256 totalShares) internal pure returns (uint256) {
return shares.mulDivDown(totalAssets + VIRTUAL_ASSETS, totalShares + VIRTUAL_SHARES);
}

/// @dev Calculates the value of the given assets quoted in shares, rounding up.
/// Note: provided that assets <= totalAssets, this function satisfies the invariant: shares <= totalShares + VIRTUAL_SHARES.
function toSharesUp(uint256 assets, uint256 totalAssets, uint256 totalShares) internal pure returns (uint256) {
return assets.mulDivUp(totalShares + VIRTUAL_SHARES, totalAssets + VIRTUAL_ASSETS);
}

/// @dev Calculates the value of the given shares quoted in assets, rounding up.
/// Note: provided that shares <= totalShares, this function satisfies the invariant: assets <= totalAssets + VIRTUAL_ASSETS.
function toAssetsUp(uint256 shares, uint256 totalAssets, uint256 totalShares) internal pure returns (uint256) {
return shares.mulDivUp(totalAssets + VIRTUAL_ASSETS, totalShares + VIRTUAL_SHARES);
}
}
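As a sketch of what the library computes, the four conversions can be modeled in Python, with integers standing in for uint256 and `(x * y + d - 1) // d` standing in for solmate's `mulDivUp` (overflow reverts are not modeled). The example values are arbitrary; they illustrate that an empty market mints `VIRTUAL_SHARES` shares per asset and that the Down and Up variants differ by at most one unit.

```python
VIRTUAL_SHARES = 10**18
VIRTUAL_ASSETS = 1


def _mul_div_down(x: int, y: int, d: int) -> int:
    return x * y // d


def _mul_div_up(x: int, y: int, d: int) -> int:
    # Ceiling division; returns 0 when x * y == 0, like solmate's mulDivUp.
    return (x * y + d - 1) // d


def to_shares_down(assets: int, total_assets: int, total_shares: int) -> int:
    return _mul_div_down(assets, total_shares + VIRTUAL_SHARES, total_assets + VIRTUAL_ASSETS)


def to_assets_down(shares: int, total_assets: int, total_shares: int) -> int:
    return _mul_div_down(shares, total_assets + VIRTUAL_ASSETS, total_shares + VIRTUAL_SHARES)


def to_shares_up(assets: int, total_assets: int, total_shares: int) -> int:
    return _mul_div_up(assets, total_shares + VIRTUAL_SHARES, total_assets + VIRTUAL_ASSETS)


def to_assets_up(shares: int, total_assets: int, total_shares: int) -> int:
    return _mul_div_up(shares, total_assets + VIRTUAL_ASSETS, total_shares + VIRTUAL_SHARES)


# Empty market: 5 assets mint 5 * VIRTUAL_SHARES shares.
empty_market_shares = to_shares_down(5, 0, 0)
# 1 / 3 of VIRTUAL_SHARES is not an integer, so Down and Up differ by one.
down, up = to_shares_down(1, 2, 0), to_shares_up(1, 2, 0)
# Converting assets -> shares -> assets never returns more than was put in.
round_trip = to_assets_down(to_shares_down(7, 10, 4 * VIRTUAL_SHARES), 10, 4 * VIRTUAL_SHARES)
```

In Blue.sol, supply and repay round shares down while withdraw and borrow round them up, so each conversion errs in the protocol's favor.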
6 changes: 2 additions & 4 deletions src/mocks/IrmMock.sol
@@ -1,12 +1,10 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity 0.8.20;

-import {WadRayMath} from "morpho-utils/math/WadRayMath.sol";
-
import "src/Blue.sol";

contract IrmMock is IIrm {
-using WadRayMath for uint256;
+using FixedPointMathLib for uint256;

Blue public immutable blue;

@@ -16,7 +14,7 @@ contract IrmMock is IIrm {

function borrowRate(Market calldata market) external view returns (uint256) {
Id id = Id.wrap(keccak256(abi.encode(market)));
-uint256 utilization = blue.totalBorrow(id).wadDivDown(blue.totalSupply(id));
+uint256 utilization = blue.totalBorrow(id).divWadDown(blue.totalSupply(id));

// Divide by the number of seconds in a year.
// This is a very simple model (to refine later) where x% utilization corresponds to x% APR.
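The mock's model maps utilization linearly to an annual rate, then converts it to a per-second rate. The line performing the division is collapsed in this diff, so the Python sketch below assumes a 365-day year and plain integer division; both are assumptions, as is the example data, and `mock_borrow_rate` is a hypothetical name.

```python
WAD = 10**18
SECONDS_PER_YEAR = 365 * 24 * 60 * 60  # assumption: a 365-day year


def div_wad_down(x: int, y: int) -> int:
    return x * WAD // y


def mock_borrow_rate(total_borrow: int, total_supply: int) -> int:
    """Per-second borrow rate in WAD, where x% utilization maps to x% APR."""
    # Raises ZeroDivisionError when total_supply == 0 (the Solidity call would revert).
    utilization = div_wad_down(total_borrow, total_supply)
    return utilization // SECONDS_PER_YEAR


rate = mock_borrow_rate(500, 1000)  # 50% utilization
# Annualizing the per-second rate recovers ~50% APR, up to rounding.
annual = rate * SECONDS_PER_YEAR
```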
72 changes: 40 additions & 32 deletions test/forge/Blue.t.sol
@@ -4,8 +4,6 @@ pragma solidity 0.8.20;
import {IERC20} from "src/interfaces/IERC20.sol";
import {IOracle} from "src/interfaces/IOracle.sol";

-import {WadRayMath} from "morpho-utils/math/WadRayMath.sol";
-
import "forge-std/Test.sol";
import "forge-std/console.sol";

@@ -15,7 +13,7 @@ import {OracleMock as Oracle} from "src/mocks/OracleMock.sol";
import {IrmMock as Irm} from "src/mocks/IrmMock.sol";

contract BlueTest is Test {
-using WadRayMath for uint256;
+using FixedPointMathLib for uint256;

address private constant BORROWER = address(1234);
address private constant LIQUIDATOR = address(5678);
@@ -78,8 +76,8 @@ contract BlueTest is Test {
// To move to a test utils file later.

function netWorth(address user) internal view returns (uint256) {
-uint256 collateralAssetValue = collateralAsset.balanceOf(user).wadMulDown(collateralOracle.price());
-uint256 borrowableAssetValue = borrowableAsset.balanceOf(user).wadMulDown(borrowableOracle.price());
+uint256 collateralAssetValue = collateralAsset.balanceOf(user).mulWadDown(collateralOracle.price());
+uint256 borrowableAssetValue = borrowableAsset.balanceOf(user).mulWadDown(borrowableOracle.price());
return collateralAssetValue + borrowableAssetValue;
}

@@ -89,7 +87,7 @@ contract BlueTest is Test {

uint256 totalShares = blue.totalSupplyShares(id);
uint256 totalSupply = blue.totalSupply(id);
-return supplyShares.wadMulDown(totalSupply).wadDivDown(totalShares);
+return supplyShares.divWadDown(totalShares).mulWadDown(totalSupply);
}

function borrowBalance(address user) internal view returns (uint256) {
@@ -98,7 +96,7 @@ contract BlueTest is Test {

uint256 totalShares = blue.totalBorrowShares(id);
uint256 totalBorrow = blue.totalBorrow(id);
-return borrowerShares.wadMulUp(totalBorrow).wadDivUp(totalShares);
+return borrowerShares.divWadUp(totalShares).mulWadUp(totalBorrow);
}

// Invariants
@@ -209,7 +207,7 @@ contract BlueTest is Test {
borrowableAsset.setBalance(address(this), amount);
blue.supply(market, amount);

-assertEq(blue.supplyShare(id, address(this)), 1e18, "supply share");
+assertEq(blue.supplyShare(id, address(this)), amount * SharesMath.VIRTUAL_SHARES, "supply share");
assertEq(borrowableAsset.balanceOf(address(this)), 0, "lender balance");
assertEq(borrowableAsset.balanceOf(address(blue)), amount, "blue balance");
}
@@ -236,7 +234,7 @@ contract BlueTest is Test {
vm.prank(BORROWER);
blue.borrow(market, amountBorrowed);

-assertEq(blue.borrowShare(id, BORROWER), 1e18, "borrow share");
+assertEq(blue.borrowShare(id, BORROWER), amountBorrowed * SharesMath.VIRTUAL_SHARES, "borrow share");
assertEq(borrowableAsset.balanceOf(BORROWER), amountBorrowed, "BORROWER balance");
assertEq(borrowableAsset.balanceOf(address(blue)), amountLent - amountBorrowed, "blue balance");
}
@@ -266,7 +264,10 @@ contract BlueTest is Test {
blue.withdraw(market, amountWithdrawn);

assertApproxEqAbs(
-blue.supplyShare(id, address(this)), (amountLent - amountWithdrawn) * 1e18 / amountLent, 1e3, "supply share"
+blue.supplyShare(id, address(this)),
+(amountLent - amountWithdrawn) * SharesMath.VIRTUAL_SHARES,
+100,
+"supply share"
);
assertEq(borrowableAsset.balanceOf(address(this)), amountWithdrawn, "this balance");
assertEq(
@@ -296,9 +297,9 @@ contract BlueTest is Test {
vm.prank(BORROWER);
blue.supplyCollateral(market, amountCollateral);

-uint256 collateralValue = amountCollateral.wadMulDown(priceCollateral);
-uint256 borrowValue = amountBorrowed.wadMulUp(priceBorrowable);
-if (borrowValue == 0 || (collateralValue > 0 && borrowValue <= collateralValue.wadMulDown(LLTV))) {
+uint256 collateralValue = amountCollateral.mulWadDown(priceCollateral);
+uint256 borrowValue = amountBorrowed.mulWadUp(priceBorrowable);
+if (borrowValue == 0 || (collateralValue > 0 && borrowValue <= collateralValue.mulWadDown(LLTV))) {
vm.prank(BORROWER);
blue.borrow(market, amountBorrowed);
} else {
@@ -322,7 +323,10 @@ contract BlueTest is Test {
vm.stopPrank();

assertApproxEqAbs(
-blue.borrowShare(id, BORROWER), (amountBorrowed - amountRepaid) * 1e18 / amountBorrowed, 1e3, "borrow share"
+blue.borrowShare(id, BORROWER),
+(amountBorrowed - amountRepaid) * SharesMath.VIRTUAL_SHARES,
+100,
+"borrow share"
);
assertEq(borrowableAsset.balanceOf(BORROWER), amountBorrowed - amountRepaid, "BORROWER balance");
assertEq(borrowableAsset.balanceOf(address(blue)), amountLent - amountBorrowed + amountRepaid, "blue balance");
@@ -364,10 +368,10 @@ contract BlueTest is Test {
amountLent = bound(amountLent, 1000, 2 ** 64);

uint256 amountCollateral = amountLent;
-uint256 borrowingPower = amountCollateral.wadMulDown(LLTV);
-uint256 amountBorrowed = borrowingPower.wadMulDown(0.8e18);
-uint256 toSeize = amountCollateral.wadMulDown(LLTV);
-uint256 incentive = WAD + ALPHA.wadMulDown(WAD.wadDivDown(LLTV) - WAD);
+uint256 borrowingPower = amountCollateral.mulWadDown(LLTV);
+uint256 amountBorrowed = borrowingPower.mulWadDown(0.8e18);
+uint256 toSeize = amountCollateral.mulWadDown(LLTV);
+uint256 incentive = WAD + ALPHA.mulWadDown(WAD.divWadDown(LLTV) - WAD);

borrowableAsset.setBalance(address(this), amountLent);
collateralAsset.setBalance(BORROWER, amountCollateral);
@@ -394,9 +398,9 @@ contract BlueTest is Test {
uint256 liquidatorNetWorthAfter = netWorth(LIQUIDATOR);

uint256 expectedRepaid =
-toSeize.wadMulUp(collateralOracle.price()).wadDivUp(incentive).wadDivUp(borrowableOracle.price());
-uint256 expectedNetWorthAfter = liquidatorNetWorthBefore + toSeize.wadMulDown(collateralOracle.price())
-- expectedRepaid.wadMulDown(borrowableOracle.price());
+toSeize.mulWadUp(collateralOracle.price()).divWadUp(incentive).divWadUp(borrowableOracle.price());
+uint256 expectedNetWorthAfter = liquidatorNetWorthBefore + toSeize.mulWadDown(collateralOracle.price())
+- expectedRepaid.mulWadDown(borrowableOracle.price());
assertEq(liquidatorNetWorthAfter, expectedNetWorthAfter, "LIQUIDATOR net worth");
assertApproxEqAbs(borrowBalance(BORROWER), amountBorrowed - expectedRepaid, 100, "BORROWER balance");
assertEq(blue.collateral(id, BORROWER), amountCollateral - toSeize, "BORROWER collateral");
@@ -407,10 +411,10 @@ contract BlueTest is Test {
amountLent = bound(amountLent, 1000, 2 ** 64);

uint256 amountCollateral = amountLent;
-uint256 borrowingPower = amountCollateral.wadMulDown(LLTV);
-uint256 amountBorrowed = borrowingPower.wadMulDown(0.8e18);
+uint256 borrowingPower = amountCollateral.mulWadDown(LLTV);
+uint256 amountBorrowed = borrowingPower.mulWadDown(0.8e18);
uint256 toSeize = amountCollateral;
-uint256 incentive = WAD + ALPHA.wadMulDown(WAD.wadDivDown(market.lltv) - WAD);
+uint256 incentive = WAD + ALPHA.mulWadDown(WAD.divWadDown(market.lltv) - WAD);

borrowableAsset.setBalance(address(this), amountLent);
collateralAsset.setBalance(BORROWER, amountCollateral);
@@ -437,9 +441,9 @@ contract BlueTest is Test {
uint256 liquidatorNetWorthAfter = netWorth(LIQUIDATOR);

uint256 expectedRepaid =
-toSeize.wadMulUp(collateralOracle.price()).wadDivUp(incentive).wadDivUp(borrowableOracle.price());
-uint256 expectedNetWorthAfter = liquidatorNetWorthBefore + toSeize.wadMulDown(collateralOracle.price())
-- expectedRepaid.wadMulDown(borrowableOracle.price());
+toSeize.mulWadUp(collateralOracle.price()).divWadUp(incentive).divWadUp(borrowableOracle.price());
+uint256 expectedNetWorthAfter = liquidatorNetWorthBefore + toSeize.mulWadDown(collateralOracle.price())
+- expectedRepaid.mulWadDown(borrowableOracle.price());
assertEq(liquidatorNetWorthAfter, expectedNetWorthAfter, "LIQUIDATOR net worth");
assertEq(borrowBalance(BORROWER), 0, "BORROWER balance");
assertEq(blue.collateral(id, BORROWER), 0, "BORROWER collateral");
@@ -460,9 +464,13 @@ contract BlueTest is Test {
blue.supply(market, secondAmount);

assertApproxEqAbs(supplyBalance(address(this)), firstAmount, 100, "same balance first user");
-assertEq(blue.supplyShare(id, address(this)), 1e18, "expected shares first user");
+assertEq(
+blue.supplyShare(id, address(this)), firstAmount * SharesMath.VIRTUAL_SHARES, "expected shares first user"
+);
assertApproxEqAbs(supplyBalance(BORROWER), secondAmount, 100, "same balance second user");
-assertEq(blue.supplyShare(id, BORROWER), secondAmount * 1e18 / firstAmount, "expected shares second user");
+assertApproxEqAbs(
+blue.supplyShare(id, BORROWER), secondAmount * SharesMath.VIRTUAL_SHARES, 100, "expected shares second user"
+);
}
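Without borrows there is no interest, so the rate set by the virtual shares stays at exactly VIRTUAL_SHARES shares per asset, which is what the updated assertions expect. A Python sketch with a model of toSharesDown (integers stand in for uint256; `supply_twice` is a hypothetical helper and the amounts are arbitrary):

```python
VIRTUAL_SHARES = 10**18
VIRTUAL_ASSETS = 1


def to_shares_down(assets: int, total_assets: int, total_shares: int) -> int:
    # Model of SharesMath.toSharesDown.
    return assets * (total_shares + VIRTUAL_SHARES) // (total_assets + VIRTUAL_ASSETS)


def supply_twice(first_amount: int, second_amount: int) -> tuple[int, int]:
    """Returns the shares minted to the first and second supplier of an empty market."""
    first_shares = to_shares_down(first_amount, 0, 0)
    second_shares = to_shares_down(second_amount, first_amount, first_shares)
    return first_shares, second_shares


first, second = supply_twice(123, 456_789)
```

In this model the second supplier's shares are exact because totalShares + VIRTUAL_SHARES equals (totalAssets + VIRTUAL_ASSETS) * VIRTUAL_SHARES; the test's tolerance of 100 leaves room beyond that.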

function testUnknownMarket(Market memory marketFuzz) public {
@@ -514,12 +522,12 @@ contract BlueTest is Test {
}

function testEmptyMarket(uint256 amount) public {
-vm.assume(amount > 0);
+amount = bound(amount, 1, type(uint256).max / SharesMath.VIRTUAL_SHARES);

-vm.expectRevert();
+vm.expectRevert(stdError.arithmeticError);
blue.withdraw(market, amount);

-vm.expectRevert();
+vm.expectRevert(stdError.arithmeticError);
blue.repay(market, amount);

vm.expectRevert(stdError.arithmeticError);