Add BinaryTreePartition class for hierarchical image-pair partitioning by tiantianxiangshang629 · Pull Request #840 · borglab/gtsfm · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add BinaryTreePartition class for hierarchical image-pair partitioning #840

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions gtsfm/graph_partitioner/binary_tree_partition.py
10000
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""Implementation of a binary tree graph partitioner.

This partitioner recursively partitions image pair graphs into a binary tree
structure up to a specified depth, using METIS-based ordering. Leaf nodes
represent explicit image keys and associated edge groupings.

Authors: Shicong Ma
"""

from typing import Dict, List, Tuple

import gtsam
import networkx as nx
from gtsam import SymbolicFactorGraph

import gtsfm.utils.logger as logger_utils
from gtsfm.graph_partitioner.graph_partitioner_base import GraphPartitionerBase

logger = logger_utils.get_logger()


class BinaryTreeNode:
    """A single node of the binary partition tree.

    Internal nodes are created with an empty key list and later receive two
    children; leaf nodes hold the image keys assigned to them and have no
    children.
    """

    def __init__(self, keys: List[int], depth: int):
        """
        Create a node that initially has no children.

        Args:
            keys: Keys stored at this node (meaningful at leaves only).
            depth: Level of this node in the tree; the root is built at depth 0.
        """
        self.depth = depth
        self.keys = keys  # populated only for leaves
        self.left = None
        self.right = None

    def is_leaf(self) -> bool:
        """Return True when the node has neither a left nor a right child."""
        return all(child is None for child in (self.left, self.right))


class BinaryTreePartition(GraphPartitionerBase):
    """Graph partitioner that uses a binary tree to recursively divide image pairs.

    Image keys are ordered with METIS, and the ordering is split in half
    recursively until ``max_depth`` is reached, producing ``2**max_depth`` leaves.
    Each image pair is assigned either to the single leaf that contains both of
    its images (an "internal" edge) or, when its images fall into two different
    leaves, to both of those leaves (a "shared" edge). No input pair is dropped.
    """

    def __init__(self, max_depth: int = 2):
        """
        Initialize the BinaryTreePartition object.

        Args:
            max_depth: Maximum depth of the binary tree; results in 2^depth partitions.
                Must be non-negative.

        Raises:
            ValueError: If max_depth is negative (the recursive split would never terminate).
        """
        if max_depth < 0:
            raise ValueError(f"max_depth must be non-negative, got {max_depth}.")
        super().__init__(process_name="BinaryTreePartition")
        self.max_depth = max_depth

    def partition_image_pairs(self, image_pairs: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
        """Partition image pairs into subgroups using a binary tree.

        Args:
            image_pairs: List of image index pairs (i, j), where i < j.

        Returns:
            A list of image pair subsets, one for each leaf in the binary tree (2^max_depth
            entries). Each subset lists the leaf's internal edges followed by its shared
            edges; every returned pair is ordered as (i, j) with i < j.
        """
        if not image_pairs:
            logger.warning("No image pairs provided for partitioning.")
            return []

        symbol_graph, _, nx_graph = self._build_graphs(image_pairs)
        ordering = gtsam.Ordering.MetisSymbolicFactorGraph(symbol_graph)
        binary_tree_root_node = self._build_binary_partition(ordering)

        partition_details = self._compute_leaf_partition_details(binary_tree_root_node, nx_graph)

        logger.info(f"BinaryTreePartition: partitioned into {len(partition_details)} leaf nodes.")

        # One subgraph per leaf: internal edges first, then edges shared with other leaves.
        image_pairs_per_partition = [
            part.get("edges_within_explicit", []) + part.get("edges_with_shared", []) for part in partition_details
        ]

        for i, part in enumerate(partition_details):
            explicit_keys = part.get("explicit_keys", [])
            edges_within = part.get("edges_within_explicit", [])
            edges_shared = part.get("edges_with_shared", [])

            logger.info(
                f"Partition {i}:\n"
                f"  Explicit Image Keys that only exist within the current partition "
                f"({len(explicit_keys)}): {sorted(explicit_keys)}\n"
                f"  Internal Edges ({len(edges_within)}): {edges_within}\n"
                f"  Shared Edges ({len(edges_shared)}): {edges_shared}\n"
            )

        return image_pairs_per_partition

    def _build_graphs(self, image_pairs: List[Tuple[int, int]]) -> Tuple[SymbolicFactorGraph, List[int], nx.Graph]:
        """Construct GTSAM and NetworkX graphs from image pairs.

        Args:
            image_pairs: List of image index pairs.

        Returns:
            A tuple of (SymbolicFactorGraph, list of keys, NetworkX graph). Nodes of both
            graphs are gtsam symbol keys built as symbol('x', image_index).
        """
        sfg = gtsam.SymbolicFactorGraph()
        nxg = nx.Graph()
        keys = set()

        for i, j in image_pairs:
            key_i = gtsam.symbol("x", i)
            key_j = gtsam.symbol("x", j)
            keys.add(key_i)
            keys.add(key_j)

            sfg.push_factor(key_i, key_j)
            nxg.add_edge(key_i, key_j)

        return sfg, list(keys), nxg

    def _build_binary_partition(self, ordering: gtsam.Ordering) -> BinaryTreeNode:
        """Build a binary tree of image keys based on a given ordering.

        The ordered keys are halved at every level until depth ``max_depth`` is reached,
        so the tree always has exactly 2^max_depth leaves (some may be empty when there
        are fewer keys than leaves).

        Args:
            ordering: GTSAM Ordering object created via METIS.

        Returns:
            Root node of the binary tree.
        """
        ordered_keys = [ordering.at(i) for i in range(ordering.size())]

        def split(keys: List[int], depth: int) -> BinaryTreeNode:
            if depth == self.max_depth:
                return BinaryTreeNode(keys, depth)

            mid = len(keys) // 2
            node = BinaryTreeNode([], depth)
            node.left = split(keys[:mid], depth + 1)
            node.right = split(keys[mid:], depth + 1)
            return node

        return split(ordered_keys, 0)

    @staticmethod
    def _collect_leaves(node: BinaryTreeNode) -> List[BinaryTreeNode]:
        """Return the leaves under ``node`` in left-to-right order."""
        if node.is_leaf():
            return [node]
        return BinaryTreePartition._collect_leaves(node.left) + BinaryTreePartition._collect_leaves(node.right)

    @staticmethod
    def _to_image_pair(u: int, v: int) -> Tuple[int, int]:
        """Convert two gtsam symbol keys into an image-index pair ordered as (i, j) with i < j."""
        i, j = gtsam.Symbol(u).index(), gtsam.Symbol(v).index()
        # nx.Graph.edges() does not preserve the orientation of the input pairs.
        return (i, j) if i < j else (j, i)

    def _compute_leaf_partition_details(
        self,
        node: BinaryTreeNode,
        nx_graph: nx.Graph,
    ) -> List[Dict]:
        """Compute partition details for every leaf under ``node``.

        An edge whose two images lie in the same leaf is an internal edge of that leaf.
        An edge whose images lie in different leaves is a shared edge and is added to
        both leaves, at any level of the tree, so no image pair is lost.

        Args:
            node: Root of the (sub)tree being processed.
            nx_graph: NetworkX graph built from image pairs.

        Returns:
            A list with one dictionary per leaf, in left-to-right order, with keys
            "explicit_keys", "explicit_count", "edges_within_explicit" and "edges_with_shared".
        """
        leaves = self._collect_leaves(node)
        # Every graph key appears in the METIS ordering, hence in exactly one leaf.
        key_to_leaf = {key: leaf_idx for leaf_idx, leaf in enumerate(leaves) for key in leaf.keys}

        details: List[Dict] = []
        for leaf in leaves:
            explicit_keys = set(leaf.keys)
            details.append(
                {
                    "explicit_keys": [gtsam.Symbol(u).index() for u in explicit_keys],
                    "explicit_count": len(explicit_keys),
                    "edges_within_explicit": [],
                    "edges_with_shared": [],
                }
            )

        for u, v in nx_graph.edges():
            pair = self._to_image_pair(u, v)
            leaf_u = key_to_leaf[u]
            leaf_v = key_to_leaf[v]
            if leaf_u == leaf_v:
                details[leaf_u]["edges_within_explicit"].append(pair)
            else:
                details[leaf_u]["edges_with_shared"].append(pair)
                details[leaf_v]["edges_with_shared"].append(pair)

        return details
31 changes: 15 additions & 16 deletions gtsfm/runner/gtsfm_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,27 @@
import dask
import hydra
import numpy as np

from dask import config as dask_config
from dask.distributed import Client, LocalCluster, SSHCluster, performance_report
from gtsam import Rot3, Pose3, Unit3
from gtsam import Pose3, Rot3, Unit3
from hydra.utils import instantiate
from omegaconf import OmegaConf

import gtsfm.evaluation.metrics_report as metrics_report
import gtsfm.utils.merging as merging_utils
import gtsfm.utils.logger as logger_utils
import gtsfm.utils.merging as merging_utils
import gtsfm.utils.metrics as metrics_utils
import gtsfm.utils.viz as viz_utils
from gtsfm.common.gtsfm_data import GtsfmData
from gtsfm import two_view_estimator
from gtsfm.common.gtsfm_data import GtsfmData
from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup
from gtsfm.frontend.correspondence_generator.image_correspondence_generator import ImageCorrespondenceGenerator
from gtsfm.graph_partitioner.graph_partitioner_base import GraphPartitionerBase
from gtsfm.loader.loader_base import LoaderBase
from gtsfm.retriever.retriever_base import ImageMatchingRegime
from gtsfm.scene_optimizer import SceneOptimizer
from gtsfm.two_view_estimator import TWO_VIEW_OUTPUT, TwoViewEstimationReport, run_two_view_estimator_as_futures
from gtsfm.ui.process_graph_generator import ProcessGraphGenerator
from gtsfm.graph_partitioner.graph_partitioner_base import GraphPartitionerBase
from gtsfm.graph_partitioner.single_partition import SinglePartition
from gtsfm.utils.subgraph_utils import group_results_by_subgraph

dask_config.set({"distributed.scheduler.worker-ttl": None})
Expand All @@ -55,6 +53,7 @@ def __init__(self, override_args=None) -> None:

self.loader: LoaderBase = self.construct_loader()
self.scene_optimizer: SceneOptimizer = self.construct_scene_optimizer()
self.graph_partitioner: GraphPartitionerBase = self.scene_optimizer.graph_partitioner

def construct_argparser(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=self.tag)
Expand Down Expand Up @@ -283,14 +282,10 @@ def setup_ssh_cluster_with_retries(self) -> SSHCluster:
)
return cluster

def run(self, graph_partitioner: GraphPartitionerBase = None) -> GtsfmData:
def run(self) -> GtsfmData:
"""Run the SceneOptimizer."""
start_time = time.time()

# Create graph partitioner if not provided
if graph_partitioner is None:
graph_partitioner = SinglePartition()

# Create dask cluster.
if self.parsed_args.cluster_config:
cluster = self.setup_ssh_cluster_with_retries()
Expand Down Expand Up @@ -322,7 +317,7 @@ def run(self, graph_partitioner: GraphPartitionerBase = None) -> GtsfmData:
client=client,
images=self.loader.get_all_images_as_futures(client),
image_fnames=self.loader.image_filenames(),
plots_output_dir=self.scene_optimizer._plot_base_path,
plots_output_dir=self.scene_optimizer.create_plot_base_path(),
)

retriever_metrics = self.scene_optimizer.image_pairs_generator._retriever.evaluate(
Expand Down Expand Up @@ -394,9 +389,8 @@ def run(self, graph_partitioner: GraphPartitionerBase = None) -> GtsfmData:
all_metrics_groups = [retriever_metrics, two_view_agg_metrics]

# Partition image pairs
subgraphs = graph_partitioner.partition_image_pairs(image_pair_indices)
subgraphs = self.graph_partitioner.partition_image_pairs(image_pair_indices)
logger.info(f"Partitioned into {len(subgraphs)} subgraphs")

# Group results by subgraph
subgraph_two_view_results = group_results_by_subgraph(two_view_results_dict, subgraphs)

Expand All @@ -407,9 +401,14 @@ def run(self, graph_partitioner: GraphPartitionerBase = None) -> GtsfmData:

for idx, subgraph_result_dict in enumerate(subgraph_two_view_results):
logger.info(
f"Creating computation graph for subgraph {idx+1}/{len(subgraph_two_view_results)}"
f"with {len(subgraph_result_dict)} image pairs"
f"Creating computation graph for subgraph {idx + 1}/{len(subgraph_two_view_results)} "
f"with { len(subgraph_result_dict)} image pairs"
)
if len(subgraph_two_view_results) == 1:
# single partition
self.scene_optimizer.create_output_directories(None)
else:
self.scene_optimizer.create_output_directories(idx + 1)

# Unzip the two-view results for this subgraph
subgraph_i2Ri1_dict, subgraph_i2Ui1_dict, subgraph_v_corr_idxs_dict, _, subgraph_post_isp_reports = (
Expand Down
24 changes: 18 additions & 6 deletions gtsfm/scene_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from gtsfm.common.pose_prior import PosePrior
from gtsfm.densify.mvs_base import MVSBase
from gtsfm.frontend.correspondence_generator.correspondence_generator_base import CorrespondenceGeneratorBase
from gtsfm.graph_partitioner.graph_partitioner_base import GraphPartitionerBase
from gtsfm.graph_partitioner.single_partition import SinglePartition
from gtsfm.multi_view_optimizer import MultiViewOptimizer
from gtsfm.retriever.image_pairs_generator import ImagePairsGenerator
from gtsfm.retriever.retriever_base import ImageMatchingRegime
Expand Down Expand Up @@ -75,6 +77,7 @@ def __init__(
pose_angular_error_thresh: float = 3, # in degrees
output_root: str = DEFAULT_OUTPUT_ROOT,
output_worker: Optional[str] = None,
graph_partitioner: Optional[GraphPartitionerBase] = SinglePartition(),
) -> None:
self.image_pairs_generator = image_pairs_generator
self.correspondence_generator = correspondence_generator
Expand All @@ -90,7 +93,7 @@ def __init__(
self._pose_angular_error_thresh = pose_angular_error_thresh
self.output_root = Path(output_root)
self._output_worker = output_worker
self._create_output_directories()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this removed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented a new create_output_directories function to create the partition result directories.

self.graph_partitioner = graph_partitioner

def __repr__(self) -> str:
"""Returns string representation of class."""
Expand All @@ -102,12 +105,21 @@ def __repr__(self) -> str:
DenseMultiviewOptimizer: {self.dense_multiview_optimizer}
"""

def _create_output_directories(self) -> None:
def create_plot_base_path(self):
"""Create plot base path."""
plot_base_path = self.output_root / "plots"
os.makedirs(plot_base_path, exist_ok=True)
return plot_base_path

def create_output_directories(self, partition_index: Optional[int]) -> None:
"""Create various output directories for GTSFM results, metrics, and plots."""
# base paths for storage
self._plot_base_path = self.output_root / "plots"
self._metrics_path = self.output_root / "result_metrics"
self._results_path = self.output_root / "results"
# Construct subfolder if partitioned
partition_folder = f"partition_{partition_index}" if partition_index is not None else ""

# Base paths
self._plot_base_path = self.output_root / "plots" / partition_folder
self._metrics_path = self.output_root / "result_metrics" / partition_folder
self._results_path = self.output_root / "results" / partition_folder

# plot paths
self._plot_correspondence_path = self._plot_base_path / "correspondences"
Expand Down
11 changes: 9 additions & 2 deletions gtsfm/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Authors: Ayush Baid, John Lambert
"""

import glob
import os
import pickle
Expand All @@ -15,7 +16,7 @@
import numpy as np
import open3d
import simplejson as json
from gtsam import Cal3Bundler, Point3, Pose3, Rot3, SfmTrack
from gtsam import Cal3Bundler, Cal3DS2, Point3, Pose3, Rot3, SfmTrack
from PIL import Image as PILImage
from PIL.ExifTags import GPSTAGS, TAGS

Expand Down Expand Up @@ -223,10 +224,16 @@ def colmap2gtsfm(
elif camera_model_name == "RADIAL":
f, cx, cy, k1, k2 = cameras[img.camera_id].params[:5]
fx = f
elif camera_model_name == "OPENCV":
fx, fy, cx, cy, k1, k2, p1, p2 = cameras[img.camera_id].params[:8]
else:
raise ValueError(f"Unsupported COLMAP camera type: {camera_model_name}")

intrinsics_gtsfm.append(Cal3Bundler(fx, k1, k2, cx, cy))
if camera_model_name == "OPENCV":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were you able to run GTSfM with Cal3DS2? I think it did not perform as well as Cal3Bundler when we used it before.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was. The reason for using it is that Cal3Bundler does not accept those distortion arguments from OPENCV.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did you get it to run with Cal3DS2 without adding Cal3DS2 support to the bundle adjustment module? Only Cal3Bundler and Cal3Fisheye are supported I believe.

intrinsics_gtsfm.append(Cal3DS2(fx, fy, 0.0, cx, cy, k1, k2, p1, p2))
else:
intrinsics_gtsfm.append(Cal3Bundler(fx, k1, k2, cx, cy))

image_id_to_idx[img.id] = idx
img_h, img_w = cameras[img.camera_id].height, cameras[img.camera_id].width
img_dims.append((img_h, img_w))
Expand Down
Loading
Loading
0