Is there a text-to-video alignment metric for custom videos? #111
I am also getting the same issue, with the same dimensions.
Hey, so I found a way to hack the code so that multiple_objects can run on custom inputs: the key objects are extracted automatically from the prompt with spaCy, and 'multiple_objects' is removed from the set of dimensions blocked for custom input. Here is my modified VBench class:
import os
import importlib
import spacy
from itertools import chain
from pathlib import Path
from .utils import get_prompt_from_filename, init_submodules, save_json, load_json
from .distributed import get_rank, print0

# Load spaCy model for object extraction
nlp = spacy.load("en_core_web_sm")


def extract_objects_from_prompt(prompt):
    """
    Extracts objects (nouns) from a given text prompt.
    Returns a list of extracted objects.
    """
    doc = nlp(prompt)
    objects = [token.text for token in doc if token.pos_ in ("NOUN", "PROPN")]  # Extract nouns & proper nouns
    # Keep only the first two nouns; prompts with fewer than two nouns fall back to placeholders
    return objects[:2] if len(objects) >= 2 else ["unknown_object1", "unknown_object2"]
class VBench:
    def __init__(self, device, full_info_dir, output_path):
        self.device = device
        self.full_info_dir = full_info_dir
        self.output_path = output_path
        os.makedirs(self.output_path, exist_ok=True)

    def build_full_dimension_list(self):
        return [
            "subject_consistency", "background_consistency", "aesthetic_quality",
            "imaging_quality", "object_class", "multiple_objects", "color",
            "spatial_relationship", "scene", "temporal_style", "overall_consistency",
            "human_action", "temporal_flickering", "motion_smoothness", "dynamic_degree",
            "appearance_style"
        ]

    def check_dimension_requires_extra_info(self, dimension_list):
        # 'multiple_objects' is deliberately left out of this blocked set so it can be
        # evaluated on custom inputs, with the key objects auto-extracted from the prompt.
        dim_custom_not_supported = set(dimension_list) & {
            'object_class', 'scene', 'appearance_style', 'color', 'spatial_relationship'
        }
        assert len(dim_custom_not_supported) == 0, f"Dimensions {dim_custom_not_supported} not supported for custom input."
    def build_full_info_json(self, videos_path, name, dimension_list, prompt_list=[], special_str='', verbose=False, mode='vbench_standard', **kwargs):
        """
        Builds the evaluation JSON file, ensuring proper handling of custom inputs.
        """
        cur_full_info_list = []
        if mode == 'custom_input':
            self.check_dimension_requires_extra_info(dimension_list)
            if os.path.isfile(videos_path):
                # Single video file: take the prompt from the filename (or the single prompt provided)
                prompt = get_prompt_from_filename(videos_path)
                extracted_objects = extract_objects_from_prompt(prompt)
                cur_full_info_list.append({
                    "prompt_en": prompt_list[0] if len(prompt_list) == 1 else prompt,
                    "dimension": dimension_list,
                    "video_list": [videos_path],
                    "auxiliary_info": {"object": " and ".join(extracted_objects)}
                })
            else:
                # Directory of videos: one entry per supported video file
                video_names = os.listdir(videos_path)
                for filename in video_names:
                    postfix = Path(os.path.join(videos_path, filename)).suffix
                    if postfix.lower() not in ['.mp4', '.gif', '.jpg', '.png']:
                        continue
                    prompt = get_prompt_from_filename(filename)
                    extracted_objects = extract_objects_from_prompt(prompt)
                    cur_full_info_list.append({
                        "prompt_en": prompt,
                        "dimension": dimension_list,
                        "video_list": [os.path.join(videos_path, filename)],
                        "auxiliary_info": {"object": " and ".join(extracted_objects)}
                    })
            if len(prompt_list) > 0:
                # prompt_list is a dict mapping video filename (relative to videos_path) -> prompt
                prompt_list = {os.path.join(videos_path, path): prompt_list[path] for path in prompt_list}
                assert len(prompt_list) >= len(cur_full_info_list), """
                Number of prompts should match the number of videos.
                To read the prompt from filename, delete --prompt_file and --prompt_list
                """
                all_video_path = [os.path.abspath(file) for file in chain.from_iterable(vid["video_list"] for vid in cur_full_info_list)]
                missing_videos = set(all_video_path) - set([os.path.abspath(path_key) for path_key in prompt_list])
                assert len(missing_videos) == 0, f"Prompts for the following videos are missing:\n{chr(10).join(missing_videos)}"
                video_map = {os.path.abspath(prompt_key): prompt_list[prompt_key] for prompt_key in prompt_list}
                for video_info in cur_full_info_list:
                    video_info["prompt_en"] = video_map[os.path.abspath(video_info["video_list"][0])]
        elif mode == 'vbench_category':
            self.check_dimension_requires_extra_info(dimension_list)
            CUR_DIR = os.path.dirname(os.path.abspath(__file__))
            # NOTE: categories are listed from the repo-level prompts/ folder, but the .txt files
            # below are read from the package directory; make sure both locations are consistent.
            category_supported = [Path(category).stem for category in os.listdir('prompts/prompts_per_category')]
            category = kwargs.get('category', category_supported)
            assert category is not None, "Please specify the category to be evaluated with --category"
            assert category in category_supported, f"Category '{category}' is not supported."
            video_names = os.listdir(videos_path)
            postfix = Path(video_names[0]).suffix
            with open(f'{CUR_DIR}/prompts_per_category/{category}.txt', 'r') as f:
                video_prompts = [line.strip() for line in f.readlines()]
            for prompt in video_prompts:
                video_list = []
                for filename in video_names:
                    if not Path(filename).stem.startswith(prompt):
                        continue
                    postfix = Path(os.path.join(videos_path, filename)).suffix
                    if postfix.lower() not in ['.mp4', '.gif', '.jpg', '.png']:
                        continue
                    video_list.append(os.path.join(videos_path, filename))
                cur_full_info_list.append({
                    "prompt_en": prompt,
                    "dimension": dimension_list,
                    "video_list": video_list
                })
        else:
            # Standard VBench mode: match videos to the benchmark prompts in full_info_dir
            full_info_list = load_json(self.full_info_dir)
            video_names = os.listdir(videos_path)
            postfix = Path(video_names[0]).suffix
            for prompt_dict in full_info_list:
                if set(dimension_list) & set(prompt_dict["dimension"]):
                    prompt = prompt_dict['prompt_en']
                    prompt_dict['video_list'] = []
                    for i in range(5):
                        intended_video_name = f'{prompt}{special_str}-{str(i)}{postfix}'
                        if intended_video_name in video_names:
                            intended_video_path = os.path.join(videos_path, intended_video_name)
                            prompt_dict['video_list'].append(intended_video_path)
                            if verbose:
                                print0(f'Successfully found video: {intended_video_name}')
                        else:
                            print0(f'WARNING!!! Missing video: {intended_video_name}')
                    cur_full_info_list.append(prompt_dict)

        cur_full_info_path = os.path.join(self.output_path, name + '_full_info.json')
        save_json(cur_full_info_list, cur_full_info_path)
        print0(f'Evaluation metadata saved to {cur_full_info_path}')
        return cur_full_info_path
    def evaluate(self, videos_path, name, prompt_list=[], dimension_list=None, local=False, read_frame=False, mode='vbench_standard', **kwargs):
        results_dict = {}
        if dimension_list is None:
            dimension_list = self.build_full_dimension_list()
        submodules_dict = init_submodules(dimension_list, local=local, read_frame=read_frame)
        cur_full_info_path = self.build_full_info_json(videos_path, name, dimension_list, prompt_list, mode=mode, **kwargs)
        for dimension in dimension_list:
            try:
                dimension_module = importlib.import_module(f'vbench.{dimension}')
                evaluate_func = getattr(dimension_module, f'compute_{dimension}')
            except Exception as e:
                raise NotImplementedError(f'Unimplemented dimension {dimension}!, {e}')
            submodules_list = submodules_dict[dimension]
            print0(f'cur_full_info_path: {cur_full_info_path}')
            results = evaluate_func(cur_full_info_path, self.device, submodules_list, **kwargs)
            results_dict[dimension] = results
        output_name = os.path.join(self.output_path, name + '_eval_results.json')
        if get_rank() == 0:
            save_json(results_dict, output_name)
        print0(f'Evaluation results saved to {output_name}')
And here is the modified multiple_objects module (vbench/multiple_objects.py):

import os
import json
import torch
import numpy as np
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.grit_model import DenseCaptioning
from torchvision import transforms
import logging
from .distributed import (
    get_world_size,
    get_rank,
    all_gather,
    barrier,
    distribute_list_to_rank,
    gather_list_of_dict,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def get_dect_from_grit(model, image_arrays):
    """ Extract object detections from video frames. """
    pred = []
    if not isinstance(image_arrays, list):
        image_arrays = image_arrays.numpy()
    with torch.no_grad():
        for frame in image_arrays:
            ret = model.run_caption_tensor(frame)
            if len(ret[0]) > 0:
                pred.append(set(ret[0][0][2]))  # Set of object labels detected in this frame
            else:
                pred.append(set([]))
    return pred


def check_generate(key_objects, predictions):
    """ Check if all key objects appear in frame predictions. """
    cur_cnt = 0
    key_objects = set(key_objects)  # Ensure it's a set
    for pred in predictions:
        # Exact string match: a frame only counts if every key object label is present
        if key_objects.issubset(pred):
            cur_cnt += 1
    return cur_cnt
def multiple_objects(model, video_dict, device):
    """ Evaluate multiple object detection in videos. """
    success_frame_count, frame_count = 0, 0
    video_results = []
    for info in tqdm(video_dict, disable=get_rank() > 0):
        if 'auxiliary_info' not in info:
            logger.error("Auxiliary info is missing in JSON: %s", info)
            raise Exception("Auxiliary info is not in json, please check your json.")
        object_info = info['auxiliary_info'].get('object', [])
        if isinstance(object_info, str):
            object_info = object_info.replace(" and ", ",").split(",")  # Convert to list
            object_info = [obj.strip() for obj in object_info]  # Ensure clean object names
        for video_path in info['video_list']:
            video_tensor = load_video(video_path, num_frames=16)
            if video_tensor is None:
                logger.warning("Failed to load video: %s", video_path)
                continue
            _, _, h, w = video_tensor.size()
            if min(h, w) > 768:
                scale = 720. / min(h, w)
                video_tensor = transforms.Resize((int(scale * h), int(scale * w)))(video_tensor)
            cur_video_pred = get_dect_from_grit(model, video_tensor.permute(0, 2, 3, 1))
            cur_success_frame_count = check_generate(object_info, cur_video_pred)
            cur_success_frame_rate = cur_success_frame_count / len(cur_video_pred)
            success_frame_count += cur_success_frame_count
            frame_count += len(cur_video_pred)
            video_results.append({
                'video_path': video_path,
                'video_results': cur_success_frame_rate,
                'success_frame_count': cur_success_frame_count,
                'frame_count': len(cur_video_pred)
            })
    success_rate = success_frame_count / frame_count if frame_count > 0 else 0
    return success_rate, video_results
def compute_multiple_objects(json_dir, device, submodules_dict, **kwargs):
    """ Compute multiple object metric using Dense Captioning model. """
    dense_caption_model = DenseCaptioning(device)
    dense_caption_model.initialize_model_det(**submodules_dict)
    logger.info("Initialize detection model success")

    # Load JSON correctly
    with open(json_dir, "r") as f:
        prompt_dict_ls = json.load(f)

    # Debugging: Check JSON structure
    logger.info("DEBUG: First 3 JSON entries: %s", prompt_dict_ls[:3])

    # Keep only the entries that actually request the multiple_objects dimension
    filtered_prompt_dict_ls = []
    for entry in prompt_dict_ls:
        if 'dimension' in entry and "multiple_objects" not in entry['dimension']:
            logger.warning("Skipping entry without multiple_objects dimension: %s", entry)
            continue
        filtered_prompt_dict_ls.append(entry)
    prompt_dict_ls = filtered_prompt_dict_ls

    prompt_dict_ls = distribute_list_to_rank(prompt_dict_ls)
    all_results, video_results = multiple_objects(dense_caption_model, prompt_dict_ls, device)
    if get_world_size() > 1:
        video_results = gather_list_of_dict(video_results)
        success_frame_count = sum(x['success_frame_count'] for x in video_results)
        frame_count = sum(x['frame_count'] for x in video_results)
        all_results = success_frame_count / frame_count if frame_count > 0 else 0
    return all_results, video_results
This does not work properly, because I am getting 0 for most video results. Hopefully the original repo authors can incorporate something like this in future iterations.
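One possible reason for the zeros (a guess, not something I have verified): check_generate does an exact set-subset test between the spaCy-extracted nouns and the GRiT detection labels, so any wording mismatch ("dogs" vs "dog", "puppy" vs "dog") makes every frame count as a miss. A quick sanity check is to print both sides for a single video. The sketch below reuses the functions from the patch above; the video path is a made-up placeholder, and it assumes dense_caption_model has already been initialized exactly as in compute_multiple_objects:

# Hypothetical sanity check: compare prompt objects against per-frame GRiT detections
from vbench.utils import load_video, get_prompt_from_filename

video_path = "my_custom_videos/a_dog_and_a_cat_playing.mp4"  # placeholder path
prompt = get_prompt_from_filename(video_path)
key_objects = set(extract_objects_from_prompt(prompt))  # spaCy extraction from the patch above

video_tensor = load_video(video_path, num_frames=16)
# dense_caption_model is assumed to be a DenseCaptioning instance initialized as in compute_multiple_objects
frame_preds = get_dect_from_grit(dense_caption_model, video_tensor.permute(0, 2, 3, 1))

print("key objects from prompt:", key_objects)
for i, detected in enumerate(frame_preds):
    # A frame only counts as a success if *all* key objects appear in this set
    print(f"frame {i:02d}: detected={sorted(detected)} match={key_objects.issubset(detected)}")

If the detected labels never contain your prompt nouns verbatim, the metric will be 0 even when the objects are visually present, which would explain the results above.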
Hi, thanks for the question! Custom video evaluation isn't supported for those dimensions because the evaluation pipeline relies on closed-domain detectors or classifiers that are trained on predefined concepts. This design choice was made to keep the evaluation reliable and accurate, given the limitations of the open-domain evaluation tools available at the time. In VBench-2.0, we've incorporated more open-domain evaluation dimensions and tools that are less constrained by a closed-domain concept set.
We generated some custom videos from our own custom text prompts, and we would like to know how well the videos align with those prompts. However, these dimensions are not supported:
AssertionError: dimensions : {'multiple_objects', 'spatial_relationship', 'object_class', 'color', 'appearance_style', 'scene'} not supported for custom input
What is the reason that these dimensions are not supported and is there a way that we can hack the code to force it to compute a metric using our generated videos from custom text prompts?
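For anyone landing here later, a minimal sketch of what does work today for custom prompt-video alignment, under the assumption that overall_consistency (which, as I understand it, scores overall text-video consistency and is not in the blocked set above) is what you want. All paths, filenames, and prompts below are placeholders; the constructor, evaluate signature, and prompt_list format follow the VBench class shown earlier in this thread:

import torch
from vbench import VBench

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# full_info_dir is only read in vbench_standard mode in the class above, but the constructor requires it
my_vbench = VBench(device, full_info_dir="VBench_full_info.json", output_path="evaluation_results")

# Map each video filename (relative to videos_path) to the prompt it was generated from
prompt_list = {
    "a_red_car_driving_on_a_beach-0.mp4": "a red car driving on a beach",
    "two_dogs_playing_in_the_snow-0.mp4": "two dogs playing in the snow",
}

my_vbench.evaluate(
    videos_path="my_custom_videos",          # folder containing the generated videos
    name="custom_alignment_eval",
    prompt_list=prompt_list,
    dimension_list=["overall_consistency"],  # prompt-dependent dimension outside the blocked set
    mode="custom_input",
)

If you skip prompt_list entirely, the custom_input branch above falls back to reading the prompt from each video's filename.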