class documentation

Wrapper for running Segment Anything inference.

Box prompt example:

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset(
    "quickstart", max_samples=25, shuffle=True, seed=51
)

model = foz.load_zoo_model("segment-anything-vitb-torch")

# Prompt with boxes
dataset.apply_model(
    model,
    label_field="segmentations",
    prompt_field="ground_truth",
)

session = fo.launch_app(dataset)
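
For orientation, FiftyOne stores each detection's bounding_box as [top-left-x, top-left-y, width, height] in relative coordinates in [0, 1], whereas Segment Anything's predictor takes box prompts as absolute-pixel [x1, y1, x2, y2]. The sketch below shows that conversion in plain Python. The helper name to_abs_xyxy is hypothetical and is not part of FiftyOne; the wrapper's internal conversion may differ.

```python
def to_abs_xyxy(bounding_box, img_width, img_height):
    """Convert a relative [x, y, w, h] box to absolute [x1, y1, x2, y2].

    Hypothetical helper for illustration only; not a FiftyOne API.
    """
    x, y, w, h = bounding_box
    return [
        x * img_width,
        y * img_height,
        (x + w) * img_width,
        (y + h) * img_height,
    ]


# A box covering the middle half of the width of a 640x480 image
box = to_abs_xyxy([0.25, 0.5, 0.5, 0.25], img_width=640, img_height=480)
print(box)
```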

Keypoint prompt example:

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset(
    "coco-2017",
    split="validation",
    label_types="detections",
    classes=["person"],
    max_samples=25,
    only_matching=True,
)

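# Generate some keypoints. The keypoint model returns more than one label
# type, so `label_field="gt"` acts as a prefix and the keypoints are stored
# in the `gt_keypoints` field
model = foz.load_zoo_model("keypoint-rcnn-resnet50-fpn-coco-torch")
dataset.default_skeleton = model.skeleton
dataset.apply_model(model, label_field="gt")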

model = foz.load_zoo_model("segment-anything-vitb-torch")

# Prompt with keypoints
dataset.apply_model(
    model,
    label_field="segmentations",
    prompt_field="gt_keypoints",
)

session = fo.launch_app(dataset)
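
As a rough illustration of what a keypoint prompt involves: FiftyOne keypoints hold (x, y) points in relative [0, 1] coordinates, while Segment Anything takes point prompts as pixel coordinates, each with a label (1 for foreground, 0 for background). The sketch below makes that conversion under two assumptions that may not match the wrapper's actual behavior: every keypoint is treated as a foreground point, and points with NaN coordinates (missing keypoints) are skipped. The helper name to_point_prompt is hypothetical.

```python
import math


def to_point_prompt(points, img_width, img_height):
    """Convert relative (x, y) keypoints to pixel coords plus point labels.

    Hypothetical helper for illustration only; not a FiftyOne API.
    """
    coords = []
    for x, y in points:
        # Assumption: missing keypoints are encoded as NaN and are skipped
        if math.isnan(x) or math.isnan(y):
            continue
        coords.append((x * img_width, y * img_height))

    # Assumption: every remaining keypoint is a foreground (label 1) prompt
    labels = [1] * len(coords)
    return coords, labels


nan = float("nan")
coords, labels = to_point_prompt(
    [(0.5, 0.25), (nan, nan), (0.75, 0.5)], img_width=640, img_height=480
)
print(coords, labels)
```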

Automatic segmentation example:

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset(
    "quickstart", max_samples=5, shuffle=True, seed=51
)

model = foz.load_zoo_model("segment-anything-vitb-torch")

# Automatic segmentation
dataset.apply_model(model, label_field="auto")

session = fo.launch_app(dataset)
Parameters
config: a SegmentAnythingModelConfig
Method __init__ Undocumented
Method predict_all Performs prediction on the given iterable of data.
Method _download_model Undocumented
Method _forward_pass Undocumented
Method _forward_pass_auto Undocumented
Method _forward_pass_boxes Undocumented
Method _forward_pass_points Undocumented
Method _get_classes Undocumented
Method _get_field Undocumented
Method _get_prompt_type Undocumented
Method _get_prompts Undocumented
Method _load_auto_generator Undocumented
Method _load_model Undocumented
Method _load_predictor Undocumented
Method _parse_samples Undocumented
Instance Variable _curr_classes Undocumented
Instance Variable _curr_prompt_type Undocumented
Instance Variable _curr_prompts Undocumented
Instance Variable _output_processor Undocumented

Inherited from TorchSamplesMixin:

Method predict Performs prediction on the given data.

Inherited from SamplesMixin (via TorchSamplesMixin):

Method needs_fields.setter Undocumented
Property needs_fields A dict mapping model-specific keys to sample field names.
Instance Variable _fields Undocumented

Inherited from TorchImageModel (via TorchSamplesMixin, SamplesMixin):

Method __enter__ Undocumented
Method __exit__ Undocumented
Method preprocess.setter Undocumented
Instance Variable config Undocumented
Property classes The list of class labels for the model, if known.
Property device The torch.device that the model is using.
Property has_logits Whether this instance can generate logits.
Property mask_targets The mask targets for the model, if any.
Property media_type The media type processed by the model.
Property num_classes The number of classes for the model, if known.
Property preprocess Whether to apply preprocessing transforms for inference, if any.
Property ragged_batches Whether transforms may return tensors of different sizes. If True, then passing ragged lists of images to predict_all may not be allowed.
Property skeleton The keypoint skeleton for the model, if any.
Property transforms A torchvision.transforms function that will be applied to each input before prediction, if any.
Property using_gpu Whether the model is using GPU.
Property using_half_precision Whether the model is using half precision.
Method _build_output_processor Undocumented
Method _build_transforms Undocumented
Method _load_transforms Undocumented
Method _parse_classes Undocumented
Method _parse_mask_targets Undocumented
Method _parse_skeleton Undocumented
Method _predict_all Applies a forward pass to the given iterable of data and returns the raw model output with no processing applied.
Instance Variable _benchmark_orig Undocumented
Instance Variable _classes Undocumented
Instance Variable _device Undocumented
Instance Variable _mask_targets Undocumented
Instance Variable _model Undocumented
Instance Variable _no_grad Undocumented
Instance Variable _preprocess Undocumented
Instance Variable _ragged_batches Undocumented
Instance Variable _skeleton Undocumented
Instance Variable _transforms Undocumented
Instance Variable _using_gpu Undocumented
Instance Variable _using_half_precision Undocumented

Inherited from TorchEmbeddingsMixin (via TorchSamplesMixin, SamplesMixin, TorchImageModel):

Method embed Generates an embedding for the given data.
Method embed_all Generates embeddings for the given iterable of data.
Method get_embeddings Returns the embeddings generated by the last forward pass of the model.
Property has_embeddings Whether this instance has embeddings.
Instance Variable _as_feature_extractor Undocumented
Instance Variable _embeddings_layer Undocumented

Inherited from LogitsMixin (via TorchSamplesMixin, SamplesMixin, TorchImageModel, TorchEmbeddingsMixin, EmbeddingsMixin, TorchModelMixin):

Method store_logits.setter Undocumented
Property store_logits Whether the model should store logits in its predictions.
Instance Variable _store_logits Undocumented

Inherited from Model (via TorchSamplesMixin, SamplesMixin, TorchImageModel, TorchEmbeddingsMixin, EmbeddingsMixin, TorchModelMixin, LogitsMixin):

Property can_embed_prompts Whether this instance can generate prompt embeddings.
def predict_all(self, imgs, samples=None):

Performs prediction on the given iterable of data.

Image models should support, at minimum, processing args values that are either lists of uint8 numpy arrays (HWC) or numpy array tensors (NHWC).

Video models should support, at minimum, processing args values that are lists of eta.core.video.VideoReader instances.

Subclasses can override this method to increase efficiency, but, by default, this method simply iterates over the data and applies predict to each.

Parameters
imgs: Undocumented
samples (default: None): an iterable of fiftyone.core.sample.Sample instances associated with the data
args: an iterable of data
Returns
a list of fiftyone.core.labels.Label instances or a list of dicts of fiftyone.core.labels.Label instances containing the predictions
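
The default behavior described above (iterate over the data and apply predict to each item) can be sketched with a toy class. ToyModel and its return values are hypothetical stand-ins, not FiftyOne classes. The sketch only illustrates the fallback loop that a subclass might override with a batched implementation.

```python
class ToyModel:
    """Hypothetical stand-in that mimics the default predict_all fallback."""

    def predict(self, arg, sample=None):
        # A real model would return Label instances; a dict is used here
        # purely to make the data flow visible
        return {"input": arg, "sample": sample}

    def predict_all(self, args, samples=None):
        # Default strategy: no batching, just call predict per item
        if samples is None:
            return [self.predict(arg) for arg in args]

        return [
            self.predict(arg, sample=sample)
            for arg, sample in zip(args, samples)
        ]


model = ToyModel()
outputs = model.predict_all(["img0", "img1"], samples=["s0", "s1"])
print(outputs)
```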
def _download_model(self, config):
def _forward_pass(self, imgs):
def _forward_pass_auto(self, imgs):

Undocumented

def _forward_pass_boxes(self, imgs):
def _forward_pass_points(self, imgs):

Undocumented

def _get_classes(self, samples, field_name):

Undocumented

def _get_field(self):

Undocumented

def _get_prompt_type(self, samples, field_name):

Undocumented

def _get_prompts(self, samples, field_name):

Undocumented

def _load_auto_generator(self):
def _load_predictor(self):
def _parse_samples(self, samples, field_name):

Undocumented

_curr_classes

Undocumented

_curr_prompt_type

Undocumented

_curr_prompts

Undocumented