Add support for SAM2.1 models and batched prompt inputs by jamjamjon · Pull Request #89 · jamjamjon/usls · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add support for SAM2.1 models and batched prompt inputs #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions examples/sam/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct Args {
scale: String,

/// SAM kind
#[argh(option, default = "String::from(\"sam\")")]
#[argh(option, default = "String::from(\"samhq\")")]
kind: String,
}

Expand Down Expand Up @@ -69,9 +69,19 @@ fn main() -> Result<()> {
// Prompt
let prompts = vec![
SamPrompt::default()
// .with_postive_point(500., 375.), // postive point
// .with_negative_point(774., 366.), // negative point
.with_bbox(215., 297., 643., 459.), // bbox
// // # demo: point + point
// .with_positive_point(500., 375.) // mid window
// .with_positive_point(1125., 625.), // car door
// // # demo: bbox
// .with_xyxy(425., 600., 700., 875.), // left wheel
// // Note: When specifying multiple boxes for multiple objects, only the last box is supported; all previous boxes will be ignored.
// .with_xyxy(75., 275., 1725., 850.)
// .with_xyxy(425., 600., 700., 875.)
// .with_xyxy(1240., 675., 1400., 750.)
// .with_xyxy(1375., 550., 1650., 800.)
// # demo: bbox + negative point
.with_xyxy(425., 600., 700., 875.) // left wheel
.with_negative_point(575., 750.), // tire
];

// Run & Annotate
Expand Down
6 changes: 6 additions & 0 deletions examples/sam2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
## Quick Start

```Shell

cargo run -r -F cuda --example sam2 -- --device cuda --scale t
```
93 changes: 93 additions & 0 deletions examples/sam2/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use anyhow::Result;
use usls::{
models::{SamPrompt, SAM2},
Annotator, DataLoader, Options, Scale,
};

// Command-line options for the SAM 2.1 example; populated by `argh::from_env()` in `main`.
// NOTE(review): argh appears to build its `--help` output from the `///` lines below, so they
// are left untouched here — confirm against the argh docs before rewording them.
#[derive(argh::FromArgs)]
/// Example
struct Args {
// Device string, converted with `try_into()` and passed to `with_model_device` for both the
// encoder and the decoder. Defaults to "cpu:0".
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,

// Model scale string, converted to `Scale` in `main` (T, S, B and L are handled there).
// Defaults to "t".
/// scale
#[argh(option, default = "String::from(\"t\")")]
scale: String,
}

fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();

let args: Args = argh::from_env();

// Build model
let (options_encoder, options_decoder) = match args.scale.as_str().try_into()? {
Scale::T => (
Options::sam2_1_tiny_encoder(),
Options::sam2_1_tiny_decoder(),
),
Scale::S => (
Options::sam2_1_small_encoder(),
Options::sam2_1_small_decoder(),
),
Scale::B => (
Options::sam2_1_base_plus_encoder(),
Options::sam2_1_base_plus_decoder(),
),
Scale::L => (
Options::sam2_1_large_encoder(),
Options::sam2_1_large_decoder(),
),
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale),
};

let options_encoder = options_encoder
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let options_decoder = options_decoder
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let mut model = SAM2::new(options_encoder, options_decoder)?;

// Load image
let xs = DataLoader::try_read_n(&["images/truck.jpg"])?;

// Prompt
let prompts = vec![SamPrompt::default()
// // # demo: point + point
// .with_positive_point(500., 375.) // mid window
// .with_positive_point(1125., 625.), // car door
// // # demo: bbox
// .with_xyxy(425., 600., 700., 875.), // left wheel
// // # demo: bbox + negative point
// .with_xyxy(425., 600., 700., 875.) // left wheel
// .with_negative_point(575., 750.), // tire
// # demo: multiple objects with boxes
.with_xyxy(75., 275., 1725., 850.)
.with_xyxy(425., 600., 700., 875.)
.with_xyxy(1375., 550., 1650., 800.)
.with_xyxy(1240., 675., 1400., 750.)];

// Run & Annotate
let ys = model.forward(&xs, &prompts)?;

// annotate
let annotator = Annotator::default()
.with_mask_style(usls::Style::mask().with_draw_mask_polygon_largest(true));

for (x, y) in xs.iter().zip(ys.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",
usls::Dir::Current
.base_dir_with_subs(&["runs", model.spec()])?
.join(usls::timestamp(None))
.display(),
))?;
}

Ok(())
}
73 changes: 0 additions & 73 deletions examples/yolo-sam/main.rs

This file was deleted.

File renamed without changes.
79 changes: 79 additions & 0 deletions examples/yolo-sam2/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use anyhow::Result;
use usls::{
models::{SamPrompt, SAM2, YOLO},
Annotator, DataLoader, Options, Scale, Style,
};

// Command-line options for the YOLO + SAM2 example; populated by `argh::from_env()` in `main`.
// NOTE(review): argh appears to build its `--help` output from the `///` lines below, so they
// are left untouched here — confirm against the argh docs before rewording them.
#[derive(argh::FromArgs)]
/// Example
struct Args {
// Device string, converted with `try_into()` and passed to `with_model_device` in `main`.
// Defaults to "cpu:0".
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
}

fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();

let args: Args = argh::from_env();

// build SAM
let (options_encoder, options_decoder) = (
Options::sam2_1_tiny_encoder().commit()?,
Options::sam2_1_tiny_decoder().commit()?,
);
let mut sam = SAM2::new(options_encoder, options_decoder)?;

// build YOLOv8
let options_yolo = Options::yolo_detect()
.with_model_scale(Scale::N)
.with_model_version(8.into())
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let mut yolo = YOLO::new(options_yolo)?;

// load one image
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;

// build annotator
let annotator = Annotator::default()
.with_polygon_style(
Style::polygon()
.with_visible(true)
.with_text_visible(true)
.show_id(true)
.show_name(true),
)
.with_mask_style(Style::mask().with_draw_mask_polygon_largest(true));

// run & annotate
let ys_det = yolo.forward(&xs)?;
for y_det in ys_det.iter() {
if let Some(hbbs) = y_det.hbbs() {
// collect hhbs
let mut prompt = SamPrompt::default();
for hbb in hbbs {
prompt = prompt.with_xyxy(hbb.xmin(), hbb.ymin(), hbb.xmax(), hbb.ymax());
}

// sam2 infer
let ys_sam = sam.forward(&xs, &[prompt])?;

// annotate
for (x, y) in xs.iter().zip(ys_sam.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",
usls::Dir::Current
.base_dir_with_subs(&["runs", "YOLO-SAM2"])?
.join(usls:: 9E81 timestamp(None))
.display(),
))?;
}
}
}

Ok(())
}
2 changes: 2 additions & 0 deletions src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod rfdetr;
mod rtdetr;
mod rtmo;
mod sam;
mod sam2;
mod sapiens;
mod slanet;
mod smolvlm;
Expand All @@ -49,6 +50,7 @@ pub use rfdetr::*;
pub use rtdetr::*;
pub use rtmo::*;
pub use sam::*;
pub use sam2::*;
pub use sapiens::*;
pub use slanet::*;
pub use smolvlm::*;
Expand Down
Loading
0