Dependency Aware Node Caching for low RAM/VRAM machines by Chargeuk · Pull Request #7509 · comfyanonymous/ComfyUI

Dependency Aware Node Caching for low RAM/VRAM machines #7509


Merged · 5 commits · Apr 11, 2025
1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -101,6 +101,7 @@ class LatentPreviewMethod(enum.Enum):
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
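As a side note, here is a minimal standalone sketch of how this mutually exclusive group behaves; the parser below is a toy reconstruction of the cache_group above, not part of the PR.

import argparse

# Toy reconstruction of the cache_group above: the three cache flags
# are mutually exclusive, so argparse rejects any combination of them.
parser = argparse.ArgumentParser()
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true")
cache_group.add_argument("--cache-lru", type=int, default=0)
cache_group.add_argument("--cache-none", action="store_true")

args = parser.parse_args(["--cache-none"])
print(args.cache_none, args.cache_lru)  # True 0

# Passing two modes at once, e.g. ["--cache-none", "--cache-lru", "10"],
# makes parse_args exit with a "not allowed with argument" error.
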
153 changes: 153 additions & 0 deletions comfy_execution/caching.py
@@ -316,3 +316,156 @@ def ensure_subcache_for(self, node_id, children_ids):
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self


class DependencyAwareCache(BasicCache):
"""
A cache implementation that tracks dependencies between nodes and manages
their execution and caching accordingly. It extends the BasicCache class.
Nodes are removed from this cache once all of their descendants have been
executed.
"""

def __init__(self, key_class):
"""
Initialize the DependencyAwareCache.

Args:
key_class: The class used for generating cache keys.
"""
super().__init__(key_class)
self.descendants = {} # Maps node_id -> set of descendant node_ids
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed

def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.

Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to initialize the cache for.
            is_changed_cache: Cache of IS_CHANGED results used to detect modified nodes.
"""
# Clear all existing cache data
self.cache.clear()
self.subcaches.clear()
self.descendants.clear()
self.ancestors.clear()
self.executed_nodes.clear()

# Call the parent method to initialize the cache with the new prompt
super().set_prompt(dynprompt, node_ids, is_changed_cache)

# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)

def _build_dependency_graph(self, dynprompt, node_ids):
"""
Build the dependency graph for all nodes.

Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to build the graph for.
"""
self.descendants.clear()
self.ancestors.clear()
for node_id in node_ids:
self.descendants[node_id] = set()
self.ancestors[node_id] = set()

for node_id in node_ids:
inputs = dynprompt.get_node(node_id)["inputs"]
for input_data in inputs.values():
if is_link(input_data): # Check if the input is a link to another node
ancestor_id = input_data[0]
self.descendants[ancestor_id].add(node_id)
self.ancestors[node_id].add(ancestor_id)

def set(self, node_id, value):
"""
Mark a node as executed and store its value in the cache.

Args:
node_id: The ID of the node to store.
value: The value to store for the node.
"""
self._set_immediate(node_id, value)
self.executed_nodes.add(node_id)
self._cleanup_ancestors(node_id)

def get(self, node_id):
"""
Retrieve the cached value for a node.

Args:
node_id: The ID of the node to retrieve.

Returns:
            The cached value for the node, or None if the node has not been cached.
"""
return self._get_immediate(node_id)

def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.

Args:
node_id: The ID of the parent node.
children_ids: List of child node IDs to associate with the parent node.

Returns:
The subcache object for the node.
"""
subcache = super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)
return subcache

def _cleanup_ancestors(self, node_id):
"""
Check if ancestors of a node can be removed from the cache.

Args:
node_id: The ID of the node whose ancestors are to be checked.
"""
for ancestor_id in self.ancestors.get(node_id, []):
if ancestor_id in self.executed_nodes:
# Remove ancestor if all its descendants have been executed
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
self._remove_node(ancestor_id)

def _remove_node(self, node_id):
"""
Remove a node from the cache.

Args:
node_id: The ID of the node to remove.
"""
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
del self.cache[cache_key]
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
del self.subcaches[subcache_key]

def clean_unused(self):
"""
        Clean up unused nodes. This is a no-op here: nodes are evicted eagerly in set() once all of their descendants have executed.
"""
pass

def recursive_debug_dump(self):
"""
Dump the cache and dependency graph for debugging.

Returns:
A list containing the cache state and dependency graph.
"""
result = super().recursive_debug_dump()
result.append({
"descendants": self.descendants,
"ancestors": self.ancestors,
"executed_nodes": list(self.executed_nodes),
})
return result
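
To make the eviction rule above concrete, here is a small self-contained sketch of the core idea, separate from the class itself: a node's cached output is dropped as soon as every one of its direct descendants has executed. The three-node graph and its names are invented for illustration.

# Toy model of dependency-aware eviction: a cached output lives only
# until every direct descendant of its node has executed.
graph = {"load": ["encode"], "encode": ["sample"], "sample": []}  # node -> descendants
ancestors = {n: {a for a, ds in graph.items() if n in ds} for n in graph}

cache = {}
executed = set()

def run(node):
    cache[node] = f"<output of {node}>"
    executed.add(node)
    for anc in ancestors[node]:
        # Evict an ancestor once all of its descendants have run.
        if all(d in executed for d in graph[anc]):
            cache.pop(anc, None)

for node in ["load", "encode", "sample"]:
    run(node)
    print(node, "->", sorted(cache))
# load -> ['load']
# encode -> ['encode']   ('load' evicted once 'encode' ran)
# sample -> ['sample']   ('encode' evicted once 'sample' ran)

After each step the toy cache holds only the most recent output, which is the memory profile --cache-none trades re-execution time for.
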
31 changes: 19 additions & 12 deletions execution.py
@@ -15,7 +15,7 @@
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input

class ExecutionResult(Enum):
@@ -60,26 +60,32 @@ def get(self, node_id):
return self.is_changed[node_id]

class CacheSet:
-    def __init__(self, lru_size=None):
-        if lru_size is None or lru_size == 0:
+    def __init__(self, lru_size=None, cache_none=False):
+        if cache_none:
+            self.init_dependency_aware_cache()
+        elif lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]

-    # Useful for those with ample RAM/VRAM -- allows experimenting without
-    # blowing away the cache every time
-    def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.objects = HierarchicalCache(CacheKeySetID)

# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

+    def init_lru_cache(self, cache_size):
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)

+    # only hold cached items while their descendants have not executed
+    def init_dependency_aware_cache(self):
+        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
+        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
+        self.objects = DependencyAwareCache(CacheKeySetID)

def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
@@ -414,13 +420,14 @@ def pre_execute_cb(call_index):
return (ExecutionResult.SUCCESS, None, None)

class PromptExecutor:
-    def __init__(self, server, lru_size=None):
+    def __init__(self, server, lru_size=None, cache_none=False):
self.lru_size = lru_size
+        self.cache_none = cache_none
self.server = server
self.reset()

def reset(self):
-        self.caches = CacheSet(self.lru_size)
+        self.caches = CacheSet(self.lru_size, self.cache_none)
self.status_messages = []
self.success = True

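Condensed, the selection order the new constructor implements is: --cache-none takes precedence, then a non-zero --cache-lru, otherwise the classic cache. A minimal sketch of that dispatch (the function name is invented for illustration):

def pick_cache_mode(lru_size=None, cache_none=False):
    # Mirrors the precedence in CacheSet.__init__ above.
    if cache_none:
        return "dependency-aware"
    elif lru_size is None or lru_size == 0:
        return "classic"
    return f"lru({lru_size})"

print(pick_cache_mode())                 # classic
print(pick_cache_mode(lru_size=8))       # lru(8)
print(pick_cache_mode(cache_none=True))  # dependency-aware
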
2 changes: 1 addition & 1 deletion main.py
@@ -156,7 +156,7 @@ def cuda_malloc_warning():

def prompt_worker(q, server_instance):
current_time: float = 0.0
-    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru, cache_none=args.cache_none)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0