_fisher_vector.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. """
fisher_vector.py - Implementation of the Fisher vector encoding algorithm

This module contains the source code for Fisher vector computation. The
computation is separated into two distinct steps, which are called separately
by the user, namely:

learn_gmm: Used to estimate the GMM for all vectors/descriptors computed for
           all examples in the dataset (e.g. estimated using all the SIFT
           vectors computed for all images in the dataset, or at least a subset
           of this).

fisher_vector: Used to compute the Fisher vector representation for a
               single set of descriptors/vector (e.g. the SIFT
               descriptors for a single image in your dataset, or
               perhaps a test image).

Reference: Perronnin, F. and Dance, C. Fisher kernels on Visual Vocabularies
           for Image Categorization, IEEE Conference on Computer Vision and
           Pattern Recognition, 2007

Original Author: Dan Oneata (Author of the original implementation for the
                 Fisher vector computation using scikit-learn and NumPy.
                 Subsequently ported to scikit-image (here) by other authors.)
  20. """
  21. import numpy as np
# Maps the doctested functions to the optional package their examples need.
# NOTE(review): presumably consumed by the doctest collector (pytest-doctestplus
# convention) to skip these examples when scikit-learn is missing - confirm.
__doctest_requires__ = {("learn_gmm", "fisher_vector"): ["sklearn"]}
  23. class FisherVectorException(Exception):
  24. pass
  25. class DescriptorException(FisherVectorException):
  26. pass
  27. def learn_gmm(descriptors, *, n_modes=32, gm_args=None):
  28. """Estimate a Gaussian mixture model (GMM) given a set of descriptors and
  29. number of modes (i.e. Gaussians). This function is essentially a wrapper
  30. around the scikit-learn implementation of GMM, namely the
  31. :class:`sklearn.mixture.GaussianMixture` class.
  32. Due to the nature of the Fisher vector, the only enforced parameter of the
  33. underlying scikit-learn class is the covariance_type, which must be 'diag'.
  34. There is no simple way to know what value to use for `n_modes` a-priori.
  35. Typically, the value is usually one of ``{16, 32, 64, 128}``. One may train
  36. a few GMMs and choose the one that maximises the log probability of the
  37. GMM, or choose `n_modes` such that the downstream classifier trained on
  38. the resultant Fisher vectors has maximal performance.
  39. Parameters
  40. ----------
  41. descriptors : np.ndarray (N, M) or list [(N1, M), (N2, M), ...]
  42. List of NumPy arrays, or a single NumPy array, of the descriptors
  43. used to estimate the GMM. The reason a list of NumPy arrays is
  44. permissible is because often when using a Fisher vector encoding,
  45. descriptors/vectors are computed separately for each sample/image in
  46. the dataset, such as SIFT vectors for each image. If a list if passed
  47. in, then each element must be a NumPy array in which the number of
  48. rows may differ (e.g. different number of SIFT vector for each image),
  49. but the number of columns for each must be the same (i.e. the
  50. dimensionality must be the same).
  51. n_modes : int
  52. The number of modes/Gaussians to estimate during the GMM estimate.
  53. gm_args : dict
  54. Keyword arguments that can be passed into the underlying scikit-learn
  55. :class:`sklearn.mixture.GaussianMixture` class.
  56. Returns
  57. -------
  58. gmm : :class:`sklearn.mixture.GaussianMixture`
  59. The estimated GMM object, which contains the necessary parameters
  60. needed to compute the Fisher vector.
  61. References
  62. ----------
  63. .. [1] https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html
  64. Examples
  65. --------
  66. >>> from skimage.feature import fisher_vector
  67. >>> rng = np.random.Generator(np.random.PCG64())
  68. >>> sift_for_images = [rng.standard_normal((10, 128)) for _ in range(10)]
  69. >>> num_modes = 16
  70. >>> # Estimate 16-mode GMM with these synthetic SIFT vectors
  71. >>> gmm = learn_gmm(sift_for_images, n_modes=num_modes)
  72. """
  73. try:
  74. from sklearn.mixture import GaussianMixture
  75. except ImportError:
  76. raise ImportError(
  77. 'scikit-learn is not installed. Please ensure it is installed in '
  78. 'order to use the Fisher vector functionality.'
  79. )
  80. if not isinstance(descriptors, (list, np.ndarray)):
  81. raise DescriptorException(
  82. 'Please ensure descriptors are either a NumPy array, '
  83. 'or a list of NumPy arrays.'
  84. )
  85. d_mat_1 = descriptors[0]
  86. if isinstance(descriptors, list) and not isinstance(d_mat_1, np.ndarray):
  87. raise DescriptorException(
  88. 'Please ensure descriptors are a list of NumPy arrays.'
  89. )
  90. if isinstance(descriptors, list):
  91. expected_shape = descriptors[0].shape
  92. ranks = [len(e.shape) == len(expected_shape) for e in descriptors]
  93. if not all(ranks):
  94. raise DescriptorException(
  95. 'Please ensure all elements of your descriptor list ' 'are of rank 2.'
  96. )
  97. dims = [e.shape[1] == descriptors[0].shape[1] for e in descriptors]
  98. if not all(dims):
  99. raise DescriptorException(
  100. 'Please ensure all descriptors are of the same dimensionality.'
  101. )
  102. if not isinstance(n_modes, int) or n_modes <= 0:
  103. raise FisherVectorException('Please ensure n_modes is a positive integer.')
  104. if gm_args:
  105. has_cov_type = 'covariance_type' in gm_args
  106. cov_type_not_diag = gm_args['covariance_type'] != 'diag'
  107. if has_cov_type and cov_type_not_diag:
  108. raise FisherVectorException('Covariance type must be "diag".')
  109. if isinstance(descriptors, list):
  110. descriptors = np.vstack(descriptors)
  111. if gm_args:
  112. has_cov_type = 'covariance_type' in gm_args
  113. if has_cov_type:
  114. gmm = GaussianMixture(n_components=n_modes, **gm_args)
  115. else:
  116. gmm = GaussianMixture(
  117. n_components=n_modes, covariance_type='diag', **gm_args
  118. )
  119. else:
  120. gmm = GaussianMixture(n_components=n_modes, covariance_type='diag')
  121. gmm.fit(descriptors)
  122. return gmm
  123. def fisher_vector(descriptors, gmm, *, improved=False, alpha=0.5):
  124. """Compute the Fisher vector given some descriptors/vectors,
  125. and an associated estimated GMM.
  126. Parameters
  127. ----------
  128. descriptors : np.ndarray, shape=(n_descriptors, descriptor_length)
  129. NumPy array of the descriptors for which the Fisher vector
  130. representation is to be computed.
  131. gmm : :class:`sklearn.mixture.GaussianMixture`
  132. An estimated GMM object, which contains the necessary parameters needed
  133. to compute the Fisher vector.
  134. improved : bool, default=False
  135. Flag denoting whether to compute improved Fisher vectors or not.
  136. Improved Fisher vectors are L2 and power normalized. Power
  137. normalization is simply f(z) = sign(z) pow(abs(z), alpha) for some
  138. 0 <= alpha <= 1.
  139. alpha : float, default=0.5
  140. The parameter for the power normalization step. Ignored if
  141. improved=False.
  142. Returns
  143. -------
  144. fisher_vector : np.ndarray
  145. The computation Fisher vector, which is given by a concatenation of the
  146. gradients of a GMM with respect to its parameters (mixture weights,
  147. means, and covariance matrices). For D-dimensional input descriptors or
  148. vectors, and a K-mode GMM, the Fisher vector dimensionality will be
  149. 2KD + K. Thus, its dimensionality is invariant to the number of
  150. descriptors/vectors.
  151. References
  152. ----------
  153. .. [1] Perronnin, F. and Dance, C. Fisher kernels on Visual Vocabularies
  154. for Image Categorization, IEEE Conference on Computer Vision and
  155. Pattern Recognition, 2007
  156. .. [2] Perronnin, F. and Sanchez, J. and Mensink T. Improving the Fisher
  157. Kernel for Large-Scale Image Classification, ECCV, 2010
  158. Examples
  159. --------
  160. >>> from skimage.feature import fisher_vector, learn_gmm
  161. >>> sift_for_images = [np.random.random((10, 128)) for _ in range(10)]
  162. >>> num_modes = 16
  163. >>> # Estimate 16-mode GMM with these synthetic SIFT vectors
  164. >>> gmm = learn_gmm(sift_for_images, n_modes=num_modes)
  165. >>> test_image_descriptors = np.random.random((25, 128))
  166. >>> # Compute the Fisher vector
  167. >>> fv = fisher_vector(test_image_descriptors, gmm)
  168. """
  169. try:
  170. from sklearn.mixture import GaussianMixture
  171. except ImportError:
  172. raise ImportError(
  173. 'scikit-learn is not installed. Please ensure it is installed in '
  174. 'order to use the Fisher vector functionality.'
  175. )
  176. if not isinstance(descriptors, np.ndarray):
  177. raise DescriptorException('Please ensure descriptors is a NumPy array.')
  178. if not isinstance(gmm, GaussianMixture):
  179. raise FisherVectorException(
  180. 'Please ensure gmm is a sklearn.mixture.GaussianMixture object.'
  181. )
  182. if improved and not isinstance(alpha, float):
  183. raise FisherVectorException(
  184. 'Please ensure that the alpha parameter is a float.'
  185. )
  186. num_descriptors = len(descriptors)
  187. mixture_weights = gmm.weights_
  188. means = gmm.means_
  189. covariances = gmm.covariances_
  190. posterior_probabilities = gmm.predict_proba(descriptors)
  191. # Statistics necessary to compute GMM gradients wrt its parameters
  192. pp_sum = posterior_probabilities.mean(axis=0, keepdims=True).T
  193. pp_x = posterior_probabilities.T.dot(descriptors) / num_descriptors
  194. pp_x_2 = posterior_probabilities.T.dot(np.power(descriptors, 2)) / num_descriptors
  195. # Compute GMM gradients wrt its parameters
  196. d_pi = pp_sum.squeeze() - mixture_weights
  197. d_mu = pp_x - pp_sum * means
  198. d_sigma_t1 = pp_sum * np.power(means, 2)
  199. d_sigma_t2 = pp_sum * covariances
  200. d_sigma_t3 = 2 * pp_x * means
  201. d_sigma = -pp_x_2 - d_sigma_t1 + d_sigma_t2 + d_sigma_t3
  202. # Apply analytical diagonal normalization
  203. sqrt_mixture_weights = np.sqrt(mixture_weights)
  204. d_pi /= sqrt_mixture_weights
  205. d_mu /= sqrt_mixture_weights[:, np.newaxis] * np.sqrt(covariances)
  206. d_sigma /= np.sqrt(2) * sqrt_mixture_weights[:, np.newaxis] * covariances
  207. # Concatenate GMM gradients to form Fisher vector representation
  208. fisher_vector = np.hstack((d_pi, d_mu.ravel(), d_sigma.ravel()))
  209. if improved:
  210. fisher_vector = np.sign(fisher_vector) * np.power(np.abs(fisher_vector), alpha)
  211. fisher_vector = fisher_vector / np.linalg.norm(fisher_vector)
  212. return fisher_vector