# trainable_segmentation.py

  1. from skimage.feature import multiscale_basic_features
  2. try:
  3. from sklearn.exceptions import NotFittedError
  4. from sklearn.ensemble import RandomForestClassifier
  5. has_sklearn = True
  6. except ImportError:
  7. has_sklearn = False
  8. class NotFittedError(Exception):
  9. pass
  10. class TrainableSegmenter:
  11. """Estimator for classifying pixels.
  12. Parameters
  13. ----------
  14. clf : classifier object, optional
  15. classifier object, exposing a ``fit`` and a ``predict`` method as in
  16. scikit-learn's API, for example an instance of
  17. ``RandomForestClassifier`` or ``LogisticRegression`` classifier.
  18. features_func : function, optional
  19. function computing features on all pixels of the image, to be passed
  20. to the classifier. The output should be of shape
  21. ``(m_features, *labels.shape)``. If None,
  22. :func:`skimage.feature.multiscale_basic_features` is used.
  23. Methods
  24. -------
  25. compute_features
  26. fit
  27. predict
  28. """
  29. def __init__(self, clf=None, features_func=None):
  30. if clf is None:
  31. if has_sklearn:
  32. self.clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
  33. else:
  34. raise ImportError(
  35. "Please install scikit-learn or pass a classifier instance"
  36. "to TrainableSegmenter."
  37. )
  38. else:
  39. self.clf = clf
  40. self.features_func = features_func
  41. def compute_features(self, image):
  42. if self.features_func is None:
  43. self.features_func = multiscale_basic_features
  44. self.features = self.features_func(image)
  45. def fit(self, image, labels):
  46. """Train classifier using partially labeled (annotated) image.
  47. Parameters
  48. ----------
  49. image : ndarray
  50. Input image, which can be grayscale or multichannel, and must have a
  51. number of dimensions compatible with ``self.features_func``.
  52. labels : ndarray of ints
  53. Labeled array of shape compatible with ``image`` (same shape for a
  54. single-channel image). Labels >= 1 correspond to the training set and
  55. label 0 to unlabeled pixels to be segmented.
  56. """
  57. self.compute_features(image)
  58. fit_segmenter(labels, self.features, self.clf)
  59. def predict(self, image):
  60. """Segment new image using trained internal classifier.
  61. Parameters
  62. ----------
  63. image : ndarray
  64. Input image, which can be grayscale or multichannel, and must have a
  65. number of dimensions compatible with ``self.features_func``.
  66. Raises
  67. ------
  68. NotFittedError if ``self.clf`` has not been fitted yet (use ``self.fit``).
  69. """
  70. if self.features_func is None:
  71. self.features_func = multiscale_basic_features
  72. features = self.features_func(image)
  73. return predict_segmenter(features, self.clf)
  74. def fit_segmenter(labels, features, clf):
  75. """Segmentation using labeled parts of the image and a classifier.
  76. Parameters
  77. ----------
  78. labels : ndarray of ints
  79. Image of labels. Labels >= 1 correspond to the training set and
  80. label 0 to unlabeled pixels to be segmented.
  81. features : ndarray
  82. Array of features, with the first dimension corresponding to the number
  83. of features, and the other dimensions correspond to ``labels.shape``.
  84. clf : classifier object
  85. classifier object, exposing a ``fit`` and a ``predict`` method as in
  86. scikit-learn's API, for example an instance of
  87. ``RandomForestClassifier`` or ``LogisticRegression`` classifier.
  88. Returns
  89. -------
  90. clf : classifier object
  91. classifier trained on ``labels``
  92. Raises
  93. ------
  94. NotFittedError if ``self.clf`` has not been fitted yet (use ``self.fit``).
  95. """
  96. mask = labels > 0
  97. training_data = features[mask]
  98. training_labels = labels[mask].ravel()
  99. clf.fit(training_data, training_labels)
  100. return clf
  101. def predict_segmenter(features, clf):
  102. """Segmentation of images using a pretrained classifier.
  103. Parameters
  104. ----------
  105. features : ndarray
  106. Array of features, with the last dimension corresponding to the number
  107. of features, and the other dimensions are compatible with the shape of
  108. the image to segment, or a flattened image.
  109. clf : classifier object
  110. trained classifier object, exposing a ``predict`` method as in
  111. scikit-learn's API, for example an instance of
  112. ``RandomForestClassifier`` or ``LogisticRegression`` classifier. The
  113. classifier must be already trained, for example with
  114. :func:`skimage.future.fit_segmenter`.
  115. Returns
  116. -------
  117. output : ndarray
  118. Labeled array, built from the prediction of the classifier.
  119. """
  120. sh = features.shape
  121. if features.ndim > 2:
  122. features = features.reshape((-1, sh[-1]))
  123. try:
  124. predicted_labels = clf.predict(features)
  125. except NotFittedError:
  126. raise NotFittedError(
  127. "You must train the classifier `clf` first"
  128. "for example with the `fit_segmenter` function."
  129. )
  130. except ValueError as err:
  131. if err.args and 'x must consist of vectors of length' in err.args[0]:
  132. raise ValueError(
  133. err.args[0]
  134. + '\n'
  135. + "Maybe you did not use the same type of features for training the classifier."
  136. )
  137. else:
  138. raise err
  139. output = predicted_labels.reshape(sh[:-1])
  140. return output