class documentation

Wrapper for evaluating a Torch model on images.

See :ref:`this page <model-zoo-custom-models>` for example usage.

Parameters
config: a TorchImageModelConfig
Method __enter__ Undocumented
Method __exit__ Undocumented
Method __init__ Undocumented
Method predict Performs prediction on the given image.
Method predict_all Performs prediction on the given batch of images.
Method preprocess.setter Undocumented
Instance Variable config Undocumented
Property classes The list of class labels for the model, if known.
Property device The torch.device that the model is using.
Property has_logits Whether this instance can generate logits.
Property mask_targets The mask targets for the model, if any.
Property media_type The media type processed by the model.
Property num_classes The number of classes for the model, if known.
Property preprocess Whether to apply preprocessing transforms for inference, if any.
Property ragged_batches Whether transforms may return tensors of different sizes. If True, then passing ragged lists of images to predict_all may not be allowed.
Property skeleton The keypoint skeleton for the model, if any.
Property transforms A torchvision.transforms function that will be applied to each input before prediction, if any.
Property using_gpu Whether the model is using GPU.
Property using_half_precision Whether the model is using half precision.
Method _build_output_processor Undocumented
Method _build_transforms Undocumented
Method _download_model Undocumented
Method _forward_pass Undocumented
Method _load_model Undocumented
Method _load_transforms Undocumented
Method _parse_classes Undocumented
Method _parse_mask_targets Undocumented
Method _parse_skeleton Undocumented
Method _predict_all Applies a forward pass to the given iterable of data and returns the raw model output with no processing applied.
Instance Variable _benchmark_orig Undocumented
Instance Variable _classes Undocumented
Instance Variable _device Undocumented
Instance Variable _mask_targets Undocumented
Instance Variable _model Undocumented
Instance Variable _no_grad Undocumented
Instance Variable _output_processor Undocumented
Instance Variable _preprocess Undocumented
Instance Variable _ragged_batches Undocumented
Instance Variable _skeleton Undocumented
Instance Variable _transforms Undocumented
Instance Variable _using_gpu Undocumented
Instance Variable _using_half_precision Undocumented

Inherited from TorchEmbeddingsMixin:

Method embed Generates an embedding for the given data.
Method embed_all Generates embeddings for the given iterable of data.
Method get_embeddings Returns the embeddings generated by the last forward pass of the model.
Property has_embeddings Whether this instance has embeddings.
Instance Variable _as_feature_extractor Undocumented
Instance Variable _embeddings_layer Undocumented
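The summary says get_embeddings returns the embeddings from the *last* forward pass, which suggests that embed and embed_all run a forward pass and then read back a cached intermediate output (the name _embeddings_layer hints at a specific layer). The pure-Python sketch below illustrates that caching pattern under this assumption; the class and its method bodies are hypothetical. In real Torch code such caching is commonly done with `torch.nn.Module.register_forward_hook`.

```python
class ToyEmbeddingModel:
    """Hypothetical stand-in: caches an intermediate layer's output on each forward pass."""

    def __init__(self):
        self._last_embeddings = None  # populated by _forward

    def _hidden(self, x):
        # Pretend "embeddings layer": a fixed feature map
        return [x, x * x]

    def _forward(self, batch):
        hidden = [self._hidden(x) for x in batch]
        self._last_embeddings = hidden  # what a forward hook would capture
        return [sum(h) for h in hidden]  # pretend final-layer output

    def get_embeddings(self):
        if self._last_embeddings is None:
            raise ValueError("No forward pass has been run yet")
        return self._last_embeddings

    def embed_all(self, batch):
        self._forward(batch)
        return self.get_embeddings()

    def embed(self, x):
        # A single input is embedded as a batch of one
        return self.embed_all([x])[0]
```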

Inherited from LogitsMixin (via TorchEmbeddingsMixin, EmbeddingsMixin, TorchModelMixin):

Method store_logits.setter Undocumented
Property store_logits Whether the model should store logits in its predictions.
Instance Variable _store_logits Undocumented

Inherited from Model (via TorchEmbeddingsMixin, EmbeddingsMixin, TorchModelMixin, LogitsMixin):

Property can_embed_prompts Whether this instance can generate prompt embeddings.

def __enter__(self):

Undocumented

def __exit__(self, *args):

Undocumented

def predict(self, img):

Performs prediction on the given image.

Parameters
img

the image to process, which can be any of the following:

  • A PIL image
  • A uint8 numpy array (HWC)
  • A Torch tensor (CHW)
Returns
a fiftyone.core.labels.Label instance or dict of fiftyone.core.labels.Label instances containing the predictions
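The accepted inputs differ in memory layout: numpy arrays are height × width × channels (HWC) while Torch tensors are channels × height × width (CHW). The hypothetical helper below sketches that conversion on nested lists; dividing by 255 to get floats in [0, 1] mirrors what torchvision's `ToTensor` does and is an assumption here, not a description of what predict does internally.

```python
def hwc_to_chw(img):
    """Convert an HWC image (nested lists of uint8 values) to CHW floats in [0, 1].

    Hypothetical helper; the scaling by 255 mirrors torchvision's ToTensor.
    """
    height = len(img)
    width = len(img[0])
    channels = len(img[0][0])
    return [
        [[img[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]
```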

def predict_all(self, imgs):

Performs prediction on the given batch of images.

Parameters
imgs

the batch of images to process, which can be any of the following:

  • A list of PIL images
  • A list of uint8 numpy arrays (HWC)
  • A list of Torch tensors (CHW)
  • A uint8 numpy array (NHWC)
  • A Torch tensor (NCHW)
Returns
a list of fiftyone.core.labels.Label instances or a list of dicts of fiftyone.core.labels.Label instances containing the predictions
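predict_all returns one prediction per input image, in order, and predict covers the single-image case. One natural way to relate the two, and to feed a long list of images through in fixed-size batches, is sketched below; `iter_batches`, `predict_many`, `forward`, and `batch_size` are hypothetical names for illustration, not part of this class.

```python
from itertools import islice


def iter_batches(items, batch_size):
    """Yield successive lists of at most `batch_size` items (hypothetical helper)."""
    it = iter(items)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


def predict_all(imgs, forward):
    # Hypothetical stand-in: one prediction per input image, in order
    return [forward(img) for img in imgs]


def predict(img, forward):
    # A single image is handled as a batch of one
    return predict_all([img], forward)[0]


def predict_many(imgs, forward, batch_size=2):
    preds = []
    for batch in iter_batches(imgs, batch_size):
        preds.extend(predict_all(batch, forward))
    return preds
```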

@preprocess.setter
def preprocess(self, value):

Undocumented
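The setter is undocumented, but its paired property is described as "whether to apply preprocessing transforms for inference". A minimal sketch of such a property/setter pair backed by a `_preprocess` flag is shown below; the class name and the `prepare` method are hypothetical. A typical reason to turn preprocessing off is that the caller has already applied the transforms itself, for example in a data loader.

```python
class PreprocessToggle:
    """Hypothetical sketch of a `preprocess` property backed by `_preprocess`."""

    def __init__(self, transforms=None):
        self._transforms = transforms
        # Default to preprocessing only when transforms are available
        self._preprocess = transforms is not None

    @property
    def preprocess(self):
        return self._preprocess

    @preprocess.setter
    def preprocess(self, value):
        self._preprocess = bool(value)

    def prepare(self, img):
        # Apply transforms only when preprocessing is enabled and transforms exist
        if self._preprocess and self._transforms is not None:
            return self._transforms(img)
        return img
```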

@property
classes

The list of class labels for the model, if known.

@property
device

The torch.device that the model is using.

@property
has_logits

Whether this instance can generate logits.
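has_logits pairs naturally with the inherited store_logits property: storing logits only makes sense if the model can produce them. The sketch below guards the setter on that condition; the class is hypothetical, and whether the real setter raises in this case is an assumption rather than something this page documents.

```python
class LogitsToggle:
    """Hypothetical sketch: only allow storing logits when the model can produce them."""

    def __init__(self, has_logits):
        self._has_logits = has_logits
        self._store_logits = False

    @property
    def has_logits(self):
        return self._has_logits

    @property
    def store_logits(self):
        return self._store_logits

    @store_logits.setter
    def store_logits(self, flag):
        # Refuse to enable logits storage on a model without logits
        if flag and not self.has_logits:
            raise ValueError("This model cannot generate logits")
        self._store_logits = bool(flag)
```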

@property
mask_targets

The mask targets for the model, if any.
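Assuming, as elsewhere in FiftyOne, that mask targets are a dict mapping integer pixel values to label strings, they let a raw segmentation mask be read as class names. The hypothetical helpers below show that lookup; pixel values missing from the dict map to None and are left out of the counts.

```python
def mask_to_labels(mask, mask_targets):
    """Map each pixel value in a 2D mask to its label (hypothetical helper)."""
    return [[mask_targets.get(v) for v in row] for row in mask]


def label_counts(mask, mask_targets):
    """Count pixels per label, ignoring values without a target."""
    counts = {}
    for row in mask:
        for v in row:
            label = mask_targets.get(v)
            if label is not None:
                counts[label] = counts.get(label, 0) + 1
    return counts
```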

@property
media_type

The media type processed by the model.

@property
num_classes

The number of classes for the model, if known.
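Both classes and num_classes are qualified with "if known", so the count can only be derived when the class list exists. The sketch below assumes num_classes is simply the length of classes (None otherwise) and shows how that list turns a vector of per-class scores into a label; both helpers are hypothetical.

```python
def num_classes(classes):
    # The class count is only known when the class list is known
    return None if classes is None else len(classes)


def top_label(scores, classes):
    """Return the class label with the highest score (hypothetical helper)."""
    if num_classes(classes) != len(scores):
        raise ValueError("Expected one score per known class")
    best = max(range(len(scores)), key=lambda i: scores[i])
    return classes[best]
```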

@property
preprocess

Whether to apply preprocessing transforms for inference, if any.

@property
ragged_batches

Whether transforms may return tensors of different sizes. If True, then passing ragged lists of images to predict_all may not be allowed.
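Batching usually means stacking per-image tensors into one tensor (for example with `torch.stack`), which requires identical shapes. If the transforms resize while preserving aspect ratio, outputs can differ in shape and cannot be stacked. The hypothetical helpers below check stackability and show one possible workaround, grouping inputs by shape; this is not a description of how TorchImageModel handles ragged batches.

```python
def can_stack(shapes):
    """True if all tensor shapes match, so they could be stacked into one batch."""
    return len(set(shapes)) <= 1


def group_by_shape(items, shape_of):
    """Group items with identical shapes so each group can be stacked separately."""
    groups = {}
    for item in items:
        groups.setdefault(shape_of(item), []).append(item)
    return groups
```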

@property
skeleton

The keypoint skeleton for the model, if any.
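A keypoint skeleton describes which keypoints are connected. Assuming it is expressed as edges given by pairs of keypoint indices, the hypothetical helper below turns predicted points into drawable line segments, skipping edges whose endpoints were not detected (represented here as None).

```python
def skeleton_segments(points, edges):
    """Turn skeleton edges (pairs of keypoint indices) into line segments.

    Hypothetical helper; edges touching a missing (None) keypoint are skipped.
    """
    segments = []
    for i, j in edges:
        p, q = points[i], points[j]
        if p is not None and q is not None:
            segments.append((p, q))
    return segments
```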

@property
transforms

A torchvision.transforms function that will be applied to each input before prediction, if any.
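A torchvision transforms pipeline is typically built with `torchvision.transforms.Compose`, which applies a list of transforms in order. The pure-Python `compose` below is a hypothetical sketch of the same idea on plain values; real transforms operate on PIL images or tensors.

```python
def compose(*fns):
    """Chain single-argument functions left to right, like torchvision's Compose."""
    def apply(x):
        for fn in fns:
            x = fn(x)
        return x
    return apply
```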

@property
using_gpu

Whether the model is using GPU.

@property
using_half_precision

Whether the model is using half precision.
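Half precision stores values as 16-bit IEEE 754 floats (float16), which halves memory and often speeds up GPU inference at the cost of precision; in Torch a module is usually converted with `Module.half()`. The stdlib sketch below uses the `struct` format character `"e"` to round a Python float to float16 and show the resulting error; it illustrates the number format only, not how this class converts its model.

```python
import struct


def to_half(x):
    """Round a Python float to the nearest IEEE 754 half-precision (float16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def half_error(x):
    # Absolute rounding error introduced by storing x as float16
    return abs(to_half(x) - x)
```

For example, `to_half(0.1)` returns `0.0999755859375`, while values such as `0.5` are exactly representable; `65504.0` is the largest finite float16, and packing larger values raises `OverflowError`.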

def _build_output_processor(self, config):

Undocumented

def _build_transforms(self, config):

Undocumented

def _forward_pass(self, imgs):

Undocumented

def _load_transforms(self, config):

Undocumented

def _parse_classes(self, config):

Undocumented

def _parse_mask_targets(self, config):

Undocumented

def _parse_skeleton(self, config):

Undocumented

def _predict_all(self, imgs):

Applies a forward pass to the given iterable of data and returns the raw model output with no processing applied.

Parameters
imgs

an iterable of data. See predict_all for details
Returns
the raw output of the model
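_predict_all is documented as returning the raw model output with no processing applied, which implies a separate output-processing step (compare the _output_processor instance variable) that turns raw output into labels. The page does not say whether input transforms run before the forward pass; the hypothetical pipeline below assumes they do and shows how skipping only the output processor yields the raw output.

```python
def run_pipeline(imgs, transforms, forward, output_processor, raw=False):
    """Hypothetical pipeline: transforms -> forward pass -> output processing.

    With raw=True the forward-pass output is returned untouched, analogous
    to what _predict_all is documented to return.
    """
    batch = [transforms(img) if transforms is not None else img for img in imgs]
    output = forward(batch)
    if raw:
        return output
    return output_processor(output)
```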

_benchmark_orig

Undocumented

_classes

Undocumented

_device

Undocumented

_mask_targets

Undocumented

_no_grad

Undocumented

_output_processor

Undocumented

_preprocess

Undocumented

_ragged_batches

Undocumented

_skeleton

Undocumented

_transforms

Undocumented

_using_gpu

Undocumented

_using_half_precision

Undocumented