import warnings
import collections
from functools import partial
import numpy as np
from menpo.base import name_of_callable
from menpo.shape import TriMesh
from menpo.transform import PiecewiseAffine
def check_diagonal(diagonal):
    r"""
    Validate the diagonal length used to normalize the images' size.

    ``None`` is passed through untouched; any other value must be ``>= 20``.

    Parameters
    ----------
    diagonal : `int` or ``None``
        The value to validate.

    Returns
    -------
    diagonal : `int` or ``None``
        The input value, unchanged, when it is valid.

    Raises
    ------
    ValueError
        diagonal must be >= 20 or None
    """
    if diagonal is None:
        return None
    if diagonal < 20:
        raise ValueError("diagonal must be >= 20 or None")
    return diagonal
def check_landmark_trilist(image, transform, group=None):
    r"""
    Check that the landmarks of an image are triangulated when warping with a
    `menpo.transform.PiecewiseAffine` transform.

    The landmark group ``group`` of ``image`` is fetched and delegated to
    `check_trilist`, which emits a warning when the shape is not an instance
    of `menpo.shape.TriMesh` while ``transform`` is a
    `menpo.transform.PiecewiseAffine`.

    Parameters
    ----------
    image : `menpo.image.Image` or subclass
        The input image.
    transform : `menpo.transform.PiecewiseAffine`
        The transform object.
    group : `str` or ``None``, optional
        The landmark group of the shape to check.

    Warns
    -----
    UserWarning
        The given images do not have an explicit triangulation applied. A
        Delaunay Triangulation will be computed and used for warping. This may
        be suboptimal and cause warping artifacts.
    """
    check_trilist(image.landmarks[group], transform)
def check_trilist(shape, transform):
    r"""
    Check that a shape is triangulated when it is going to be warped with a
    `menpo.transform.PiecewiseAffine` transform.

    A warning is issued only when ``transform`` is a
    `menpo.transform.PiecewiseAffine` and ``shape`` is not an instance of
    `menpo.shape.TriMesh`; in every other case nothing happens.

    Parameters
    ----------
    shape : `menpo.shape.PointCloud` or subclass
        The input shape (usually the reference/mean shape of a model).
    transform : `menpo.transform.Transform`
        The transform object.

    Warns
    -----
    UserWarning
        The given images do not have an explicit triangulation applied. A
        Delaunay Triangulation will be computed and used for warping. This may
        be suboptimal and cause warping artifacts.
    """
    missing_triangulation = (isinstance(transform, PiecewiseAffine) and
                             not isinstance(shape, TriMesh))
    if missing_triangulation:
        warnings.warn('The given images do not have an explicit triangulation '
                      'applied. A Delaunay Triangulation will be computed '
                      'and used for warping. This may be suboptimal and cause '
                      'warping artifacts.')
def check_scales(scales):
    r"""
    Checks that the provided `scales` argument is either `int` or `float` or an
    iterable of those. It makes sure that it returns a `list` of `scales`.

    Nested lists/tuples are flattened recursively, e.g. ``[[1, 2], 3]``
    becomes ``[1, 2, 3]``.

    Parameters
    ----------
    scales : `int` or `float` or `list/tuple` of those
        The value to check.

    Returns
    -------
    scales : `list` of `int` or `float`
        The scales in a list.

    Raises
    ------
    ValueError
        scales must be an int/float or a list/tuple of int/float
    """
    error_msg = ("scales must be an int/float or a list/tuple of "
                 "int/float")
    if isinstance(scales, (int, float)):
        return [scales]
    try:
        n_items = len(scales)
    except TypeError:
        # Neither a number nor a sized container (e.g. None): report it with
        # the documented exception type instead of leaking a TypeError.
        raise ValueError(error_msg)
    if n_items == 1 and isinstance(scales[0], (int, float)):
        return list(scales)
    elif n_items > 1:
        # Validate the head, then recurse on the tail.
        return check_scales(scales[0]) + check_scales(scales[1:])
    raise ValueError(error_msg)
def check_multi_scale_param(n_scales, types, param_name, param):
    r"""
    General function for checking a parameter defined for multiple scales. It
    raises an error if the parameter is not an iterable with the correct size
    and correct types.

    Parameters
    ----------
    n_scales : `int`
        The number of scales.
    types : `tuple`
        The `tuple` of variable types that the parameter is allowed to have.
    param_name : `str`
        The name of the parameter.
    param : `types` or iterable of `types`
        The parameter value. A single value of an allowed type, or an iterable
        of 1 or ``n_scales`` such values.

    Returns
    -------
    param : `list` of `types`
        The list of values per scale.

    Raises
    ------
    ValueError
        {param_name} must be in {types} or a list/tuple of {types} with the
        same length as the number of scales
    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives in
    # `collections.abc` (Python 2 only has it in `collections`).
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    error_msg = ("{0} must be in {1} or a list/tuple of "
                 "{1} with the same length as the number "
                 "of scales".format(param_name, types))
    # A single value of an allowed type is shared by every scale.
    if isinstance(param, types):
        return [param] * n_scales
    if not isinstance(param, Iterable):
        raise ValueError(error_msg)
    # Materialize once so that iterators/generators (which have no len()) are
    # handled as well.
    values = list(param)
    if not all(isinstance(v, types) for v in values):
        raise ValueError(error_msg)
    if len(values) == 1:
        return values * n_scales
    if len(values) == n_scales:
        return values
    raise ValueError(error_msg)
def check_callable(callables, n_scales):
    r"""
    Checks the callable type per level.

    Parameters
    ----------
    callables : `callable` or `list` of `callables`
        The callable to be used per scale. A single callable, or an iterable
        of 1 or ``n_scales`` callables.
    n_scales : `int`
        The number of scales.

    Returns
    -------
    callable_list : `list`
        A `list` of callables, one per scale.

    Raises
    ------
    ValueError
        callables must be a callable or a list/tuple of callables with the same
        length as the number of scales
    """
    error_msg = ("callables must be a callable or a list/tuple of "
                 "callables with the same length as the number "
                 "of scales")
    if callable(callables):
        return [callables] * n_scales
    try:
        candidates = list(callables)
    except TypeError:
        # Not callable and not iterable (e.g. None).
        raise ValueError(error_msg)
    # Builtin all() replaces np.alltrue, which was removed in NumPy 2.0.
    if all(callable(f) for f in candidates):
        if len(candidates) == 1:
            return candidates * n_scales
        if len(candidates) == n_scales:
            return candidates
    raise ValueError(error_msg)
def check_patch_shape(patch_shape, n_scales):
    r"""
    Function for checking a multi-scale `patch_shape` parameter value.

    Accepted forms are a single shape ``(h, w)`` with `int` entries (shared by
    all scales), a 1-element list/tuple holding one shape (also shared by all
    scales), or a list/tuple with one shape per scale.

    Parameters
    ----------
    patch_shape : `list/tuple` of `int` or `list/tuple` of those
        The patch shape per scale.
    n_scales : `int`
        The number of scales.

    Returns
    -------
    patch_shape : `list` of `list/tuple` of `int`
        The list of patch shape per scale.

    Raises
    ------
    ValueError
        patch_shape must be a list/tuple of int or a list/tuple of list/tuple
        of int/float with the same length as the number of scales
    """
    error_msg = ("patch_shape must be a list/tuple of int or a "
                 "list/tuple of lit/tuple of int/float with the "
                 "same length as the number of scales")
    try:
        n_items = len(patch_shape)
    except TypeError:
        # e.g. a bare int where a (h, w) pair was expected.
        raise ValueError(error_msg)
    if n_items == 2 and isinstance(patch_shape[0], int):
        # A single (h, w) shape shared by every scale.
        return [patch_shape] * n_scales
    elif n_items == 1:
        # A single wrapped shape: broadcast it to all scales (previously this
        # always returned a 1-element list regardless of n_scales).
        return check_patch_shape(patch_shape[0], n_scales)
    elif n_items == n_scales:
        # One shape per scale: validate each entry on its own.
        return [check_patch_shape(shape, 1)[0] for shape in patch_shape]
    raise ValueError(error_msg)
def check_max_components(max_components, n_scales, var_name):
    r"""
    Checks the maximum number of components per scale. It must be ``None`` or
    `int` or `float` or a `list` of those containing ``1`` or ``{n_scales}``
    elements.

    Parameters
    ----------
    max_components : ``None`` or `int` or `float` or a `list` of those
        The value to check.
    n_scales : `int`
        The number of scales.
    var_name : `str`
        The name of the variable (used in the error message).

    Returns
    -------
    max_components : `list` of ``None`` or `int` or `float`
        A new list with the max components per scale.

    Raises
    ------
    ValueError
        {var_name} must be None or an int > 0 or a 0 <= float <= 1 or a list of
        those containing 1 or {n_scales} elements
    """
    str_error = ("{} must be None or an int > 0 or a 0 <= float <= 1 or "
                 "a list of those containing 1 or {} elements").format(
                     var_name, n_scales)
    if not isinstance(max_components, (list, tuple)):
        max_components_list = [max_components] * n_scales
    elif len(max_components) == 1:
        max_components_list = [max_components[0]] * n_scales
    elif len(max_components) == n_scales:
        # Copy so the result is always a list and never aliases the caller's
        # argument.
        max_components_list = list(max_components)
    else:
        raise ValueError(str_error)
    # NOTE(review): only the type is validated; the ranges stated in the
    # error message (int > 0, 0 <= float <= 1) are not enforced here.
    for comp in max_components_list:
        if comp is not None and not isinstance(comp, (int, float)):
            raise ValueError(str_error)
    return max_components_list
def check_max_iters(max_iters, n_scales):
    r"""
    Function that checks the value of a `max_iters` parameter defined for
    multiple scales. It must be `int` or `list` of `int`.

    A single integer (or a 1-element list when ``n_scales > 1``) is treated as
    a total budget and split evenly across the scales with `np.round`.

    Parameters
    ----------
    max_iters : `int` or `list` of `int`
        The value to check.
    n_scales : `int`
        The number of scales.

    Returns
    -------
    max_iters : `ndarray` of `int`
        The values per scale.

    Raises
    ------
    ValueError
        max_iters can be integer, integer list containing 1 or {n_scales}
        elements or None
    """
    # NumPy integer scalars are accepted too; `bool` (an int subclass) was
    # never accepted as a scalar and still is not.
    is_scalar = (isinstance(max_iters, (int, np.integer)) and
                 not isinstance(max_iters, bool))
    if is_scalar:
        max_iters = [np.round(max_iters / n_scales)
                     for _ in range(n_scales)]
    elif len(max_iters) == 1 and n_scales > 1:
        max_iters = [np.round(max_iters[0] / n_scales)
                     for _ in range(n_scales)]
    elif len(max_iters) != n_scales:
        raise ValueError('max_iters can be integer, integer list '
                         'containing 1 or {} elements or '
                         'None'.format(n_scales))
    # `np.int` (an alias of the builtin `int`) was removed in NumPy 1.24.
    return np.require(max_iters, dtype=int)
def check_sampling(sampling, n_scales):
    r"""
    Function that checks the value of a `sampling` parameter defined for
    multiple scales. It must be `int` or `ndarray` or ``None`` or a
    `list/tuple` of those.

    Parameters
    ----------
    sampling : `int` or `ndarray` or ``None`` or `list` of those
        The value to check.
    n_scales : `int`
        The number of scales.

    Returns
    -------
    sampling : `list` of `int` or `ndarray` or ``None``
        The list of values per scale.

    Raises
    ------
    ValueError
        A sampling list can only contain 1 element or {n_scales} elements
    ValueError
        sampling can be an integer or ndarray, a integer or ndarray list
        containing 1 or {n_scales} elements or None
    """
    # `np.int` (an alias of builtin `int`) was removed in NumPy 1.24; NumPy
    # integer scalars are accepted as well.
    valid_types = (np.ndarray, int, np.integer)

    def _is_valid(value):
        # Each element (or the scalar itself) may also be None.
        return value is None or isinstance(value, valid_types)

    if (isinstance(sampling, (list, tuple)) and
            all(_is_valid(s) for s in sampling)):
        if len(sampling) == 1:
            return list(sampling) * n_scales
        elif len(sampling) == n_scales:
            return list(sampling)
        raise ValueError('A sampling list can only '
                         'contain 1 element or {} '
                         'elements'.format(n_scales))
    elif _is_valid(sampling):
        return [sampling] * n_scales
    raise ValueError('sampling can be an integer or ndarray, '
                     'a integer or ndarray list '
                     'containing 1 or {} elements or '
                     'None'.format(n_scales))
def set_models_components(models, n_components):
    r"""
    Function that sets the number of active components to a list of models.

    Each model's ``n_active_components`` attribute is assigned. Passing
    ``None`` leaves the models untouched.

    Parameters
    ----------
    models : `list`
        The list of models per scale.
    n_components : `int` or `float` or ``None`` or `list` of those
        The number of components per model. A scalar (Python or NumPy) or a
        1-element list is applied to every model; otherwise one value per
        model is expected.

    Raises
    ------
    ValueError
        n_components can be an integer or a float or None or a list containing
        1 or {n_scales} of those
    """
    if n_components is None:
        return
    n_scales = len(models)
    # NumPy scalars are accepted too (previously `type(x) is int/float`
    # rejected e.g. np.int64 and then crashed in len()); `bool` stays excluded.
    is_scalar = (isinstance(n_components,
                            (int, float, np.integer, np.floating)) and
                 not isinstance(n_components, bool))
    if is_scalar:
        for am in models:
            am.n_active_components = n_components
    elif len(n_components) == 1 and n_scales > 1:
        for am in models:
            am.n_active_components = n_components[0]
    elif len(n_components) == n_scales:
        for am, n in zip(models, n_components):
            am.n_active_components = n
    else:
        raise ValueError('n_components can be an integer or a float '
                         'or None or a list containing 1 or {} of '
                         'those'.format(n_scales))
def check_model(model, cls):
    r"""
    Ensure that `model` is an instance of the required class `cls`.

    Parameters
    ----------
    model : `object`
        The object to validate.
    cls : `class`
        The class (or base class) that `model` must be an instance of.

    Raises
    ------
    ValueError
        Model must be a {cls} instance.
    """
    if isinstance(model, cls):
        return
    required_name = name_of_callable(cls)
    raise ValueError('Model must be a {} instance.'.format(required_name))
def check_algorithm_cls(algorithm_cls, n_scales, base_algorithm_cls):
    r"""
    Function that checks whether the `list` of `class` objects defined per scale
    are subclasses of the provided base `class`.

    Each entry may be a class or a `functools.partial` wrapping a class.

    Parameters
    ----------
    algorithm_cls : `class` or `partial` or `list` of those
        The object(s) per scale. A single entry, or a list/tuple of 1 or
        ``n_scales`` entries.
    n_scales : `int`
        The number of scales.
    base_algorithm_cls : `class`
        The required base class.

    Returns
    -------
    algorithm_cls : `list`
        The entries per scale.

    Raises
    ------
    ValueError
        algorithm_cls must be a subclass of {base_algorithm_cls} or a list/tuple
        of {base_algorithm_cls} subclasses with the same length as the number of
        scales {n_scales}
    """
    error_msg = ("algorithm_cls must be a subclass of {} or a "
                 "list/tuple of {} subclasses with the same length "
                 "as the number of scales {}"
                 .format(base_algorithm_cls, base_algorithm_cls, n_scales))

    def _is_valid(candidate):
        # Unwrap a partial; only genuine classes can be subclasses. Using
        # `__mro__` (not `.mro()`) avoids a crash when candidate is `type`,
        # and a partial of a plain function no longer raises AttributeError.
        target = candidate.func if isinstance(candidate, partial) else candidate
        return (isinstance(target, type) and
                base_algorithm_cls in target.__mro__)

    if _is_valid(algorithm_cls):
        return [algorithm_cls] * n_scales
    try:
        n_items = len(algorithm_cls)
    except TypeError:
        # A non-matching class/partial/object: report the documented error.
        raise ValueError(error_msg)
    if n_items == 1:
        return check_algorithm_cls(algorithm_cls[0], n_scales,
                                   base_algorithm_cls)
    elif n_items == n_scales:
        return [check_algorithm_cls(a, 1, base_algorithm_cls)[0]
                for a in algorithm_cls]
    raise ValueError(error_msg)
def check_graph(graph, graph_types, param_name, n_scales):
    r"""
    Checks the provided graph per pyramidal level. Each graph must be ``None``
    or have a type that exactly matches one of `graph_types` (the check uses
    ``type(g)``, so subclasses of the allowed types are not accepted).

    Parameters
    ----------
    graph : `graph` or ``None`` or `list` of those
        The graph argument to check. A single graph, or a list of 1 or
        ``n_scales`` graphs.
    graph_types : `graph` type or `list/tuple` of `graph` types
        The allowed graph type(s).
    param_name : `str`
        The name of the graph parameter.
    n_scales : `int`
        The number of pyramidal levels.

    Returns
    -------
    graph : `list` of `graph` types
        The graph per scale in a `list`.

    Raises
    ------
    ValueError
        {param_name} must be a list of length equal to the number of scales.
    ValueError
        {param_name} must be a list of {graph_types_str}. {} given instead.
    """
    # Expand the graph argument to one entry per scale.
    if not isinstance(graph, list):
        graphs = [graph] * n_scales
    elif len(graph) == 1:
        graphs = graph * n_scales
    elif len(graph) == n_scales:
        # Copy so the returned list does not alias the caller's list.
        graphs = list(graph)
    else:
        raise ValueError('{} must be a list of length equal to the number of '
                         'scales.'.format(param_name))
    # A tuple of types is documented as valid; previously it was wrapped as
    # [tuple], which rejected every graph and crashed on tuple.__name__.
    if isinstance(graph_types, (list, tuple)):
        allowed_types = list(graph_types)
    else:
        allowed_types = [graph_types]
    for g in graphs:
        if g is not None and type(g) not in allowed_types:
            graph_types_str = ' or '.join(gt.__name__ for gt in allowed_types)
            raise ValueError('{} must be a list of {}. {} given '
                             'instead.'.format(param_name, graph_types_str,
                                               type(g).__name__))
    return graphs