fit.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493
  1. import inspect
  2. import math
  3. from typing import Protocol, runtime_checkable, Self
  4. from warnings import warn, catch_warnings
  5. import numpy as np
  6. from numpy.linalg import inv
  7. from scipy import optimize, spatial
  8. from .._shared.utils import (
  9. _deprecate_estimate,
  10. FailedEstimation,
  11. deprecate_parameter,
  12. deprecate_func,
  13. DEPRECATED,
  14. )
  15. _EPSILON = np.spacing(1)
  16. def _check_data_dim(data, dim):
  17. if data.ndim != 2 or data.shape[1] != dim:
  18. raise ValueError(f"Input data must have shape (N, {dim}).")
  19. def _check_data_atleast_2D(data):
  20. if data.ndim < 2 or data.shape[1] < 2:
  21. raise ValueError('Input data must be at least 2D.')
  22. @runtime_checkable
  23. class RansacModelProtocol(Protocol):
  24. """Protocol for `ransac` model class."""
  25. @classmethod
  26. def from_estimate(cls, *data): ...
  27. def residuals(self, *data): ...
  28. _PARAMS_DEP_START = '0.26'
  29. _PARAMS_DEP_STOP = '2.2'
  30. class BaseModel:
  31. def __init_subclass__(self):
  32. warn(
  33. f'`BaseModel` deprecated since version {_PARAMS_DEP_START} and '
  34. f'will be removed in version {_PARAMS_DEP_STOP}',
  35. category=FutureWarning,
  36. stacklevel=2,
  37. )
  38. class _BaseModel:
  39. """Implement common methods for model classes.
  40. This class can be removed when we expire deprecations of ``estimate``
  41. method, and `params` arguments to ``predict*`` methods.
  42. Note that each inheriting class will need to implement
  43. ``_params2init_values``, that breaks up the ``params`` vector into separate
  44. components comprising the arguments to the function ``__init__``, and
  45. checks the resulting input arguments for validity.
  46. """
  47. @classmethod
  48. def from_estimate(cls, data) -> Self | FailedEstimation:
  49. # In order to defer to the ``_estimate`` method, we first need to
  50. # create an empty not-initialized instance, that we can override by
  51. # executing the ``_estimate`` method. This relies on the assumption
  52. # that `_estimate` can work with an uninitialized instance. This
  53. # assumption only need hold until we can expire the deprecation of the
  54. # `estimate` method, at which point we can move the estimation logic
  55. # from the ``_estimate`` methods, to the respective ``from_estimate``
  56. # class methods.
  57. with catch_warnings(action='ignore'):
  58. tf = cls()
  59. msg = tf._estimate(data, warn_only=False)
  60. return tf if msg is None else FailedEstimation(f'{cls.__name__}: {msg}')
  61. def _get_init_values(self, params):
  62. if params is None or params is DEPRECATED:
  63. if getattr(self, self._init_args[0]) is None:
  64. # Until the deprecation of no-argument initialization expires,
  65. # it is easy to create a not-initialized model, evidenced by
  66. # None values of the init attributes.
  67. cls_name = type(self).__name__
  68. raise ValueError(
  69. '`params` argument must be specified when '
  70. 'applied to model initialized with '
  71. f'``{cls_name}()``; Consider creating new '
  72. f'{cls_name} with suitable input arguments, '
  73. f'or by using ``{cls_name}.from_estimate``.'
  74. )
  75. return [getattr(self, a) for a in self._init_args]
  76. return self._params2init_values(params)
  77. def _warn_or_msg(msg, warn_only=True):
  78. """If `warn_only`, warn with `msg`, return ``None``, else return `msg`
  79. For `from_estimate` API, we want to return a ``FailedEstimation`` for these
  80. estimation failures, which we do by setting ``warn_only=False``, and
  81. passing back the `msg` from the ``_estimation`` method via this function.
  82. For the deprecated ``estimate`` API, we want to warn (``warn_only=True``),
  83. and return an incomplete transform. The ``None`` return value indicates
  84. the estimation has kind-of succeeded, for back compatibility.
  85. """
  86. if not warn_only:
  87. return msg
  88. warn(msg, category=RuntimeWarning, stacklevel=5)
  89. return None
  90. def _deprecate_no_args(cls):
  91. """Class decorator to allow, deprecate no input arguments to ``__init__``.
  92. Makes a new ``__init__`` method, that a) will allow option of passing no
  93. arguments, and b) when used thus, raises a deprecation warning. Otherwise
  94. defers to an assumed-existing ``_args_init`` instance method to deal with
  95. input arguments. If there are no parameters, set desired parameters to
  96. None, to signal uninitialized object.
  97. At the end of deprecation we can drop this decorator, and rename
  98. ``_args_init`` to ``__init__``.
  99. """
  100. args_init_sig = inspect.signature(cls._args_init)
  101. cls._init_args = [k for k in args_init_sig.parameters if k != 'self']
  102. def init(self, *args, **kwargs):
  103. if len(args) or len(kwargs):
  104. self._args_init(*args, **kwargs)
  105. return
  106. warn(
  107. f'Calling ``{cls.__name__}()`` (without arguments) has been '
  108. f'deprecated since version {_PARAMS_DEP_START} and will be '
  109. f'removed in version {_PARAMS_DEP_STOP}; see help for '
  110. f'``{cls.__name__}``.',
  111. category=FutureWarning,
  112. stacklevel=2,
  113. )
  114. # Blank initialization.
  115. for k in cls._init_args:
  116. setattr(self, k, None)
  117. init.__signature__ = args_init_sig
  118. cls.__init__ = init
  119. return cls
  120. def _deprecate_model_params(func):
  121. """Deprecate `params` argument of various model methods."""
  122. func = deprecate_parameter(
  123. 'params',
  124. start_version=_PARAMS_DEP_START,
  125. stop_version=_PARAMS_DEP_STOP,
  126. modify_docstring=False,
  127. )(func)
  128. func.__doc__ = func.__doc__.replace('{{ start_version }}', _PARAMS_DEP_START)
  129. return func
  130. @_deprecate_no_args
  131. class LineModelND(_BaseModel):
  132. """Total least squares estimator for N-dimensional lines.
  133. In contrast to ordinary least squares line estimation, this estimator
  134. minimizes the orthogonal distances of points to the estimated line.
  135. Lines are defined by a point (origin) and a unit vector (direction)
  136. according to the following vector equation::
  137. X = origin + lambda * direction
  138. Parameters
  139. ----------
  140. origin : array-like, shape (N,)
  141. Coordinates of line origin in N dimensions.
  142. direction : array-like, shape (N,)
  143. Vector giving line direction.
  144. Raises
  145. ------
  146. ValueError
  147. If length of `origin` and `direction` differ.
  148. Examples
  149. --------
  150. >>> x = np.linspace(1, 2, 25)
  151. >>> y = 1.5 * x + 3
  152. >>> lm = LineModelND.from_estimate(np.stack([x, y], axis=-1))
  153. >>> lm.origin
  154. array([1.5 , 5.25])
  155. >>> lm.direction # doctest: +FLOAT_CMP
  156. array([0.5547 , 0.83205])
  157. >>> res = lm.residuals(np.stack([x, y], axis=-1))
  158. >>> np.abs(np.round(res, 9))
  159. array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  160. 0., 0., 0., 0., 0., 0., 0., 0.])
  161. >>> np.round(lm.predict_y(x[:5]), 3)
  162. array([4.5 , 4.562, 4.625, 4.688, 4.75 ])
  163. >>> np.round(lm.predict_x(y[:5]), 3)
  164. array([1. , 1.042, 1.083, 1.125, 1.167])
  165. """
  166. def _args_init(self, origin, direction):
  167. """Initialize ``LineModelND`` instance.
  168. Parameters
  169. ----------
  170. origin : array-like, shape (N,)
  171. Coordinates of line origin in N dimensions.
  172. direction : array-like, shape (N,)
  173. Vector giving line direction.
  174. """
  175. self.origin, self.direction = self._check_init_values(origin, direction)
  176. def _check_init_values(self, origin, direction):
  177. origin, direction = (np.array(v) for v in (origin, direction))
  178. if len(origin) != len(direction):
  179. raise ValueError('Direction vector should be same length as origin point.')
  180. return origin, direction
  181. def _params2init_values(self, params):
  182. if len(params) != 2:
  183. raise ValueError('Input `params` should be length 2')
  184. return self._check_init_values(*params)
  185. @property
  186. @deprecate_func(
  187. deprecated_version=_PARAMS_DEP_START,
  188. removed_version=_PARAMS_DEP_STOP,
  189. hint='`params` attribute deprecated; use ``origin, direction`` attributes instead',
  190. )
  191. def params(self):
  192. """Return model attributes as ``origin, direction`` tuple."""
  193. return self.origin, self.direction
  194. @classmethod
  195. def from_estimate(cls, data):
  196. """Estimate line model from data.
  197. This minimizes the sum of shortest (orthogonal) distances
  198. from the given data points to the estimated line.
  199. Parameters
  200. ----------
  201. data : (N, dim) array
  202. N points in a space of dimensionality dim >= 2.
  203. Returns
  204. -------
  205. model : Self or `~.FailedEstimation`
  206. An instance of the line model if the estimation succeeded.
  207. Otherwise, we return a special ``FailedEstimation`` object to
  208. signal a failed estimation. Testing the truth value of the failed
  209. estimation object will return ``False``. E.g.
  210. .. code-block:: python
  211. model = LineModelND.from_estimate(...)
  212. if not model:
  213. raise RuntimeError(f"Failed estimation: {model}")
  214. """
  215. return super().from_estimate(data)
  216. def _estimate(self, data, warn_only=True):
  217. _check_data_atleast_2D(data)
  218. origin = data.mean(axis=0)
  219. data = data - origin
  220. if data.shape[0] == 2: # well determined
  221. direction = data[1] - data[0]
  222. norm = np.linalg.norm(direction)
  223. if norm != 0: # this should not happen to be norm 0
  224. direction /= norm
  225. elif data.shape[0] > 2: # over-determined
  226. # Note: with full_matrices=1 Python dies with joblib parallel_for.
  227. _, _, v = np.linalg.svd(data, full_matrices=False)
  228. direction = v[0]
  229. else: # under-determined
  230. return 'estimate under-determined'
  231. self.origin = origin
  232. self.direction = direction
  233. return None
  234. @_deprecate_model_params
  235. def residuals(self, data, params=DEPRECATED):
  236. """Determine residuals of data to model.
  237. For each point, the shortest (orthogonal) distance to the line is
  238. returned. It is obtained by projecting the data onto the line.
  239. Parameters
  240. ----------
  241. data : (N, dim) array
  242. N points in a space of dimension dim.
  243. Returns
  244. -------
  245. residuals : (N,) array
  246. Residual for each data point.
  247. Other parameters
  248. ----------------
  249. params : `~.DEPRECATED`, optional
  250. Optional custom parameter set in the form (`origin`, `direction`).
  251. .. deprecated:: {{ start_version }}
  252. """
  253. _check_data_atleast_2D(data)
  254. origin, direction = self._get_init_values(params)
  255. if len(origin) != data.shape[1]:
  256. raise ValueError(
  257. f'`origin` is {len(origin)}D, but `data` is {data.shape[1]}D'
  258. )
  259. res = (data - origin) - ((data - origin) @ direction)[
  260. ..., np.newaxis
  261. ] * direction
  262. return np.linalg.norm(res, axis=1)
  263. @_deprecate_model_params
  264. def predict(self, x, axis=0, params=DEPRECATED):
  265. """Predict intersection of line model with orthogonal hyperplane.
  266. Parameters
  267. ----------
  268. x : (n, 1) array
  269. Coordinates along an axis.
  270. axis : int
  271. Axis orthogonal to the hyperplane intersecting the line.
  272. Returns
  273. -------
  274. data : (n, m) array
  275. Predicted coordinates.
  276. Other parameters
  277. ----------------
  278. params : `~.DEPRECATED`, optional
  279. Optional custom parameter set in the form (`origin`, `direction`).
  280. .. deprecated:: {{ start_version }}
  281. Raises
  282. ------
  283. ValueError
  284. If the line is parallel to the given axis.
  285. """
  286. origin, direction = self._get_init_values(params)
  287. if direction[axis] == 0:
  288. # line parallel to axis
  289. raise ValueError(f'Line parallel to axis {axis}')
  290. l = (x - origin[axis]) / direction[axis]
  291. data = origin + l[..., np.newaxis] * direction
  292. return data
  293. @_deprecate_model_params
  294. def predict_x(self, y, params=DEPRECATED):
  295. """Predict x-coordinates for 2D lines using the estimated model.
  296. Alias for::
  297. predict(y, axis=1)[:, 0]
  298. Parameters
  299. ----------
  300. y : array
  301. y-coordinates.
  302. Returns
  303. -------
  304. x : array
  305. Predicted x-coordinates.
  306. Other parameters
  307. ----------------
  308. params : `~.DEPRECATED`, optional
  309. Optional custom parameter set in the form (`origin`, `direction`).
  310. .. deprecated:: {{ start_version }}
  311. """
  312. # Avoid triggering deprecationwarning in predict.
  313. tf = (
  314. self
  315. if (params is None or params is DEPRECATED)
  316. else type(self)(*self._params2init_values(params))
  317. )
  318. x = tf.predict(y, axis=1)[:, 0]
  319. return x
  320. @_deprecate_model_params
  321. def predict_y(self, x, params=DEPRECATED):
  322. """Predict y-coordinates for 2D lines using the estimated model.
  323. Alias for::
  324. predict(x, axis=0)[:, 1]
  325. Parameters
  326. ----------
  327. x : array
  328. x-coordinates.
  329. Returns
  330. -------
  331. y : array
  332. Predicted y-coordinates.
  333. Other parameters
  334. ----------------
  335. params : `~.DEPRECATED`, optional
  336. Optional custom parameter set in the form (`origin`, `direction`).
  337. .. deprecated:: {{ start_version }}
  338. """
  339. # Avoid triggering deprecationwarning in predict.
  340. tf = (
  341. self
  342. if (params is None or params is DEPRECATED)
  343. else type(self)(*self._params2init_values(params))
  344. )
  345. y = tf.predict(x, axis=0)[:, 1]
  346. return y
  347. @_deprecate_estimate
  348. def estimate(self, data):
  349. """Estimate line model from data.
  350. This minimizes the sum of shortest (orthogonal) distances
  351. from the given data points to the estimated line.
  352. Parameters
  353. ----------
  354. data : (N, dim) array
  355. N points in a space of dimensionality ``dim >= 2``.
  356. Returns
  357. -------
  358. success : bool
  359. True, if model estimation succeeds.
  360. """
  361. return self._estimate(data) is None
  362. @_deprecate_no_args
  363. class CircleModel(_BaseModel):
  364. """Total least squares estimator for 2D circles.
  365. The functional model of the circle is::
  366. r**2 = (x - xc)**2 + (y - yc)**2
  367. This estimator minimizes the squared distances from all points to the
  368. circle::
  369. min{ sum((r - sqrt((x_i - xc)**2 + (y_i - yc)**2))**2) }
  370. A minimum number of 3 points is required to solve for the parameters.
  371. Parameters
  372. ----------
  373. center : array-like, shape (2,)
  374. Coordinates of circle center.
  375. radius : float
  376. Circle radius.
  377. Notes
  378. -----
  379. The estimation is carried out using a 2D version of the spherical
  380. estimation given in [1]_.
  381. References
  382. ----------
  383. .. [1] Jekel, Charles F. Obtaining non-linear orthotropic material models
  384. for pvc-coated polyester via inverse bubble inflation.
  385. Thesis (MEng), Stellenbosch University, 2016. Appendix A, pp. 83-87.
  386. https://hdl.handle.net/10019.1/98627
  387. Raises
  388. ------
  389. ValueError
  390. If `center` does not have length 2.
  391. Examples
  392. --------
  393. >>> t = np.linspace(0, 2 * np.pi, 25)
  394. >>> xy = CircleModel((2, 3), 4).predict_xy(t)
  395. >>> model = CircleModel.from_estimate(xy)
  396. >>> model.center
  397. array([2., 3.])
  398. >>> model.radius
  399. 4.0
  400. >>> res = model.residuals(xy)
  401. >>> np.abs(np.round(res, 9))
  402. array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  403. 0., 0., 0., 0., 0., 0., 0., 0.])
  404. The estimation can fail when — for example — all the input or output
  405. points are the same. If this happens, you will get a transform that is not
  406. "truthy" - meaning that ``bool(tform)`` is ``False``:
  407. >>> # A successfully estimated model is truthy:
  408. >>> if model:
  409. ... print("Estimation succeeded.")
  410. Estimation succeeded.
  411. >>> # Not so for a degenerate model with identical points.
  412. >>> bad_data = np.ones((4, 2))
  413. >>> bad_model = CircleModel.from_estimate(bad_data)
  414. >>> if not bad_model:
  415. ... print("Estimation failed.")
  416. Estimation failed.
  417. Trying to use this failed estimation transform result will give a suitable
  418. error:
  419. >>> bad_model.residuals(xy) # doctest: +IGNORE_EXCEPTION_DETAIL
  420. Traceback (most recent call last):
  421. ...
  422. FailedEstimationAccessError: No attribute "residuals" for failed estimation ...
  423. """
  424. def _args_init(self, center, radius):
  425. """Initialize CircleModel instance.
  426. Parameters
  427. ----------
  428. center : array-like, shape (2,)
  429. Coordinates of circle center.
  430. radius : float
  431. Circle radius.
  432. """
  433. self.center, self.radius = self._check_init_values(center, radius)
  434. def _check_init_values(self, center, radius):
  435. center = np.array(center)
  436. if not len(center) == 2:
  437. raise ValueError('Center coordinates should be length 2')
  438. return center, radius
  439. def _params2init_values(self, params):
  440. params = np.array(params)
  441. if len(params) != 3:
  442. raise ValueError('Input `params` should be length 3')
  443. return self._check_init_values(params[:2], params[2])
  444. @property
  445. @deprecate_func(
  446. deprecated_version=_PARAMS_DEP_START,
  447. removed_version=_PARAMS_DEP_STOP,
  448. hint='`params` attribute deprecated; use `center, radius` attributes instead',
  449. )
  450. def params(self):
  451. """Return model attributes ``center, radius`` as 1D array."""
  452. return np.r_[self.center, self.radius]
  453. @classmethod
  454. def from_estimate(cls, data):
  455. """Estimate circle model from data using total least squares.
  456. Parameters
  457. ----------
  458. data : (N, 2) array
  459. N points with ``(x, y)`` coordinates, respectively.
  460. Returns
  461. -------
  462. model : Self or `~.FailedEstimation`
  463. An instance of the circle model if the estimation succeeded.
  464. Otherwise, we return a special ``FailedEstimation`` object to
  465. signal a failed estimation. Testing the truth value of the failed
  466. estimation object will return ``False``. E.g.
  467. .. code-block:: python
  468. model = CircleModel.from_estimate(...)
  469. if not model:
  470. raise RuntimeError(f"Failed estimation: {model}")
  471. """
  472. return super().from_estimate(data)
  473. def _estimate(self, data, warn_only=True):
  474. _check_data_dim(data, dim=2)
  475. # to prevent integer overflow, cast data to float, if it isn't already
  476. float_type = np.promote_types(data.dtype, np.float32)
  477. data = data.astype(float_type, copy=False)
  478. # normalize value range to avoid misfitting due to numeric errors if
  479. # the relative distanceses are small compared to absolute distances
  480. origin = data.mean(axis=0)
  481. data = data - origin
  482. scale = data.std()
  483. if scale < np.finfo(float_type).tiny:
  484. return _warn_or_msg(
  485. "Standard deviation of data is too small to estimate "
  486. "circle with meaningful precision.",
  487. warn_only=warn_only,
  488. )
  489. data /= scale
  490. # Adapted from a spherical estimator covered in a blog post by Charles
  491. # Jeckel (see also reference 1 above):
  492. # https://jekel.me/2015/Least-Squares-Sphere-Fit/
  493. A = np.append(data * 2, np.ones((data.shape[0], 1), dtype=float_type), axis=1)
  494. f = np.sum(data**2, axis=1)
  495. C, _, rank, _ = np.linalg.lstsq(A, f, rcond=None)
  496. if rank != 3:
  497. return _warn_or_msg(
  498. "Input does not contain enough significant data points.",
  499. warn_only=warn_only,
  500. )
  501. center = C[0:2]
  502. distances = spatial.minkowski_distance(center, data)
  503. r = np.sqrt(np.mean(distances**2))
  504. # Revert normalization and set init params.
  505. self.center = center * scale + origin
  506. self.radius = r * scale
  507. return None
  508. def residuals(self, data):
  509. """Determine residuals of data to model.
  510. For each point the shortest distance to the circle is returned.
  511. Parameters
  512. ----------
  513. data : (N, 2) array
  514. N points with ``(x, y)`` coordinates, respectively.
  515. Returns
  516. -------
  517. residuals : (N,) array
  518. Residual for each data point.
  519. """
  520. _check_data_dim(data, dim=2)
  521. xc, yc = self.center
  522. r = self.radius
  523. x = data[:, 0]
  524. y = data[:, 1]
  525. return r - np.sqrt((x - xc) ** 2 + (y - yc) ** 2)
  526. @_deprecate_model_params
  527. def predict_xy(self, t, params=DEPRECATED):
  528. """Predict x- and y-coordinates using the estimated model.
  529. Parameters
  530. ----------
  531. t : array-like
  532. Angles in circle in radians. Angles start to count from positive
  533. x-axis to positive y-axis in a right-handed system.
  534. Returns
  535. -------
  536. xy : (..., 2) array
  537. Predicted x- and y-coordinates.
  538. Other parameters
  539. ----------------
  540. params : `~.DEPRECATED`, optional
  541. Optional parameters ``xc``, ``yc``, `radius`.
  542. .. deprecated:: {{ start_version }}
  543. """
  544. t = np.asanyarray(t)
  545. (xc, yc), r = self._get_init_values(params)
  546. x = xc + r * np.cos(t)
  547. y = yc + r * np.sin(t)
  548. return np.concatenate((x[..., None], y[..., None]), axis=t.ndim)
  549. @_deprecate_estimate
  550. def estimate(self, data):
  551. """Estimate circle model from data using total least squares.
  552. Parameters
  553. ----------
  554. data : (N, 2) array
  555. N points with ``(x, y)`` coordinates, respectively.
  556. Returns
  557. -------
  558. success : bool
  559. True, if model estimation succeeds.
  560. """
  561. return self._estimate(data) is None
  562. @_deprecate_no_args
  563. class EllipseModel(_BaseModel):
  564. """Total least squares estimator for 2D ellipses.
  565. The functional model of the ellipse is::
  566. xt = xc + a*cos(theta)*cos(t) - b*sin(theta)*sin(t)
  567. yt = yc + a*sin(theta)*cos(t) + b*cos(theta)*sin(t)
  568. d = sqrt((x - xt)**2 + (y - yt)**2)
  569. where ``(xt, yt)`` is the closest point on the ellipse to ``(x, y)``. Thus
  570. d is the shortest distance from the point to the ellipse.
  571. The estimator is based on a least squares minimization. The optimal
  572. solution is computed directly, no iterations are required. This leads
  573. to a simple, stable and robust fitting method.
  574. Parameters
  575. ----------
  576. center : array-like, shape (2,)
  577. Coordinates of ellipse center.
  578. axis_lengths : array-like, shape (2,)
  579. Length of first axis and length of second axis. Call these ``a`` and
  580. ``b``.
  581. theta : float
  582. Angle of first axis.
  583. Raises
  584. ------
  585. ValueError
  586. If `center` does not have length 2.
  587. Examples
  588. --------
  589. >>> em = EllipseModel((10, 15), (8, 4), np.deg2rad(30))
  590. >>> xy = em.predict_xy(np.linspace(0, 2 * np.pi, 25))
  591. >>> ellipse = EllipseModel.from_estimate(xy)
  592. >>> ellipse.center
  593. array([10., 15.])
  594. >>> ellipse.axis_lengths
  595. array([8., 4.])
  596. >>> round(ellipse.theta, 2)
  597. 0.52
  598. >>> np.round(abs(ellipse.residuals(xy)), 5)
  599. array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  600. 0., 0., 0., 0., 0., 0., 0., 0.])
  601. The estimation can fail when — for example — all the input or output
  602. points are the same. If this happens, you will get an ellipse model for
  603. which ``bool(model)`` is ``False``:
  604. >>> # A successfully estimated model is truthy:
  605. >>> if ellipse:
  606. ... print("Estimation succeeded.")
  607. Estimation succeeded.
  608. >>> # Not so for a degenerate model with identical points.
  609. >>> bad_data = np.ones((4, 2))
  610. >>> bad_ellipse = EllipseModel.from_estimate(bad_data)
  611. >>> if not bad_ellipse:
  612. ... print("Estimation failed.")
  613. Estimation failed.
  614. Trying to use this failed estimation transform result will give a suitable
  615. error:
  616. >>> bad_ellipse.residuals(xy) # doctest: +IGNORE_EXCEPTION_DETAIL
  617. Traceback (most recent call last):
  618. ...
  619. FailedEstimationAccessError: No attribute "residuals" for failed estimation ...
  620. """
  621. def _args_init(self, center, axis_lengths, theta):
  622. """Initialize ``EllipseModel`` instance.
  623. Parameters
  624. ----------
  625. center : array-like, shape (2,)
  626. Coordinates of ellipse center.
  627. axis_lengths : array-like, shape (2,)
  628. Length of first axis and length of second axis. Call these ``a``
  629. and ``b``.
  630. theta : float
  631. Angle of first axis.
  632. """
  633. self.center, self.axis_lengths, self.theta = self._check_init_values(
  634. center, axis_lengths, theta
  635. )
  636. def _check_init_values(self, center, axis_lengths, theta):
  637. center, axis_lengths = [np.array(v) for v in (center, axis_lengths)]
  638. if not len(center) == 2:
  639. raise ValueError('Center coordinates should be length 2')
  640. if not len(axis_lengths) == 2:
  641. raise ValueError('Axis lengths should be length 2')
  642. return center, axis_lengths, theta
  643. def _params2init_values(self, params):
  644. params = np.array(params)
  645. if len(params) != 5:
  646. raise ValueError('Input `params` should be length 5')
  647. return self._check_init_values(params[:2], params[2:4], params[4])
  648. @property
  649. @deprecate_func(
  650. deprecated_version=_PARAMS_DEP_START,
  651. removed_version=_PARAMS_DEP_STOP,
  652. hint='`params` attribute deprecated; use `center, axis_lengths, theta` attributes instead',
  653. )
  654. def params(self):
  655. """Return model attributes ``center, axis_lengths, theta`` as 1D array."""
  656. return np.r_[self.center, self.axis_lengths, self.theta]
  657. @classmethod
  658. def from_estimate(cls, data):
  659. """Estimate ellipse model from data using total least squares.
  660. Parameters
  661. ----------
  662. data : (N, 2) array
  663. N points with ``(x, y)`` coordinates, respectively.
  664. Returns
  665. -------
  666. model : Self or `~.FailedEstimation`
  667. An instance of the ellipse model if the estimation succeeded.
  668. Otherwise, we return a special ``FailedEstimation`` object to
  669. signal a failed estimation. Testing the truth value of the failed
  670. estimation object will return ``False``. E.g.
  671. .. code-block:: python
  672. model = EllipseModel.from_estimate(...)
  673. if not model:
  674. raise RuntimeError(f"Failed estimation: {model}")
  675. References
  676. ----------
  677. .. [1] Halir, R.; Flusser, J. "Numerically stable direct least squares
  678. fitting of ellipses". In Proc. 6th International Conference in
  679. Central Europe on Computer Graphics and Visualization.
  680. WSCG (Vol. 98, pp. 125-132).
  681. """
  682. return super().from_estimate(data)
    def _estimate(self, data, warn_only=True):
        """Fit ellipse parameters to `data` and store them on ``self``.

        Uses the direct least-squares ellipse fit of Halir & Flusser ([1] in
        the comments below), then converts the conic coefficients to center,
        semi-axis lengths and rotation angle using the MathWorld formulas [2].

        Parameters
        ----------
        data : (N, 2) array
            N points with ``(x, y)`` coordinates.
        warn_only : bool, optional
            Forwarded to ``_warn_or_msg`` for the two input checks (fewer than
            5 points; near-zero spread). NOTE(review): its exact effect is
            defined by that helper, which is not visible here.

        Returns
        -------
        None or failure description
            ``None`` on success, after setting ``self.center``,
            ``self.axis_lengths`` and ``self.theta``. On failure, either a
            literal message string (singular scatter matrix; no eigenvector
            satisfying the ellipse constraint) or the return value of
            ``_warn_or_msg`` (assumed non-None, since ``estimate`` treats only
            ``None`` as success).
        """
        # Original Implementation: Ben Hammel, Nick Sullivan-Molina
        # another REFERENCE: [2] http://mathworld.wolfram.com/Ellipse.html
        _check_data_dim(data, dim=2)

        # Five points are the minimum needed to determine a general conic.
        if len(data) < 5:
            return _warn_or_msg(
                "Need at least 5 data points to estimate an ellipse.",
                warn_only=warn_only,
            )

        # to prevent integer overflow, cast data to float, if it isn't already
        float_type = np.promote_types(data.dtype, np.float32)
        data = data.astype(float_type, copy=False)

        # normalize value range to avoid misfitting due to numeric errors if
        # the relative distances are small compared to absolute distances
        origin = data.mean(axis=0)
        # `data - origin` allocates a new array, so the in-place division
        # below never mutates the caller's input.
        data = data - origin
        scale = data.std()
        if scale < np.finfo(float_type).tiny:
            return _warn_or_msg(
                "Standard deviation of data is too small to estimate "
                "ellipse with meaningful precision.",
                warn_only=warn_only,
            )
        data /= scale

        x = data[:, 0]
        y = data[:, 1]

        # Quadratic part of design matrix [eqn. 15] from [1]
        D1 = np.vstack([x**2, x * y, y**2]).T
        # Linear part of design matrix [eqn. 16] from [1]
        D2 = np.vstack([x, y, np.ones_like(x)]).T

        # forming scatter matrix [eqn. 17] from [1]
        S1 = D1.T @ D1
        S2 = D1.T @ D2
        S3 = D2.T @ D2

        # Constraint matrix [eqn. 18]
        C1 = np.array([[0.0, 0.0, 2.0], [0.0, -1.0, 0.0], [2.0, 0.0, 0.0]])

        try:
            # Reduced scatter matrix [eqn. 29]
            M = inv(C1) @ (S1 - S2 @ inv(S3) @ S2.T)
        except np.linalg.LinAlgError:  # LinAlgError: Singular matrix
            return 'Singular matrix from estimation'

        # M*|a b c >=l|a b c >. Find eigenvalues and eigenvectors
        # from this equation [eqn. 28]
        # M is not symmetric in general, so eig may yield complex values;
        # the `.real` at the end of this method discards imaginary parts.
        eig_vals, eig_vecs = np.linalg.eig(M)

        # eigenvector must meet constraint 4ac - b^2 > 0 to be valid
        # (the ellipse condition); evaluated column-wise for every eigenvector.
        cond = 4 * np.multiply(eig_vecs[0, :], eig_vecs[2, :]) - np.power(
            eig_vecs[1, :], 2
        )
        a1 = eig_vecs[:, (cond > 0)]

        # seeks for empty matrix: exactly one eigenvector (3 coefficients)
        # must satisfy the constraint, otherwise the fit is rejected.
        if 0 in a1.shape or len(a1.ravel()) != 3:
            return 'Eigenvector constraints not met'

        a, b, c = a1.ravel()

        # |d f g> = -S3^(-1)*S2^(T)*|a b c> [eqn. 24]
        a2 = -inv(S3) @ S2.T @ a1
        d, f, g = a2.ravel()

        # eigenvectors are the coefficients of an ellipse in general form
        # a*x^2 + 2*b*x*y + c*y^2 + 2*d*x + 2*f*y + g = 0 (eqn. 15) from [2]
        # The design matrices above fit a*x^2 + b*x*y + c*y^2 + d*x + f*y + g,
        # so the cross and linear coefficients are halved to match [2].
        b /= 2.0
        d /= 2.0
        f /= 2.0

        # finding center of ellipse [eqn.19 and 20] from [2]
        x0 = (c * d - b * f) / (b**2.0 - a * c)
        y0 = (a * f - b * d) / (b**2.0 - a * c)

        # Find the semi-axes lengths [eqn. 21 and 22] from [2]
        numerator = a * f**2 + c * d**2 + g * b**2 - 2 * b * d * f - a * c * g
        term = np.sqrt((a - c) ** 2 + 4 * b**2)
        denominator1 = (b**2 - a * c) * (term - (a + c))
        denominator2 = (b**2 - a * c) * (-term - (a + c))
        width = np.sqrt(2 * numerator / denominator1)
        height = np.sqrt(2 * numerator / denominator2)

        # angle of counterclockwise rotation of major-axis of ellipse
        # to x-axis [eqn. 23] from [2].
        # NOTE(review): when a == c the division below is by zero (numpy
        # scalar -> inf/nan with a RuntimeWarning); a nan phi is later mapped
        # to 0 by np.nan_to_num.
        phi = 0.5 * np.arctan((2.0 * b) / (a - c))
        if a > c:
            phi += 0.5 * np.pi

        # stabilize parameters:
        # sometimes small fluctuations in data can cause
        # height and width to swap
        if width < height:
            width, height = height, width
            phi += np.pi / 2

        # Orientation is only defined modulo pi for an ellipse.
        phi %= np.pi

        # Revert normalization and set parameters.
        # Only the center and the two lengths are rescaled/shifted; the angle
        # is invariant under the uniform scaling and translation used above.
        params = np.nan_to_num([x0, y0, width, height, phi]).real
        params[:4] *= scale
        params[:2] += origin

        self.center, self.axis_lengths, self.theta = (
            params[:2],
            params[2:4],
            params[-1],
        )
        return None
  776. def residuals(self, data):
  777. """Determine residuals of data to model.
  778. For each point the shortest distance to the ellipse is returned.
  779. Parameters
  780. ----------
  781. data : (N, 2) array
  782. N points with ``(x, y)`` coordinates, respectively.
  783. Returns
  784. -------
  785. residuals : (N,) array
  786. Residual for each data point.
  787. """
  788. _check_data_dim(data, dim=2)
  789. xc, yc = self.center
  790. a, b = self.axis_lengths
  791. theta = self.theta
  792. ctheta = math.cos(theta)
  793. stheta = math.sin(theta)
  794. x = data[:, 0]
  795. y = data[:, 1]
  796. N = data.shape[0]
  797. def fun(t, xi, yi):
  798. ct = math.cos(np.squeeze(t))
  799. st = math.sin(np.squeeze(t))
  800. xt = xc + a * ctheta * ct - b * stheta * st
  801. yt = yc + a * stheta * ct + b * ctheta * st
  802. return (xi - xt) ** 2 + (yi - yt) ** 2
  803. # def Dfun(t, xi, yi):
  804. # ct = math.cos(t)
  805. # st = math.sin(t)
  806. # xt = xc + a * ctheta * ct - b * stheta * st
  807. # yt = yc + a * stheta * ct + b * ctheta * st
  808. # dfx_t = - 2 * (xi - xt) * (- a * ctheta * st
  809. # - b * stheta * ct)
  810. # dfy_t = - 2 * (yi - yt) * (- a * stheta * st
  811. # + b * ctheta * ct)
  812. # return [dfx_t + dfy_t]
  813. residuals = np.empty((N,), dtype=np.float64)
  814. # initial guess for parameter t of closest point on ellipse
  815. t0 = np.arctan2(y - yc, x - xc) - theta
  816. # determine shortest distance to ellipse for each point
  817. for i in range(N):
  818. xi = x[i]
  819. yi = y[i]
  820. # faster without Dfun, because of the python overhead
  821. t, _ = optimize.leastsq(fun, t0[i], args=(xi, yi))
  822. residuals[i] = np.sqrt(fun(t, xi, yi))
  823. return residuals
  824. @_deprecate_model_params
  825. def predict_xy(self, t, params=DEPRECATED):
  826. """Predict x- and y-coordinates using the estimated model.
  827. Parameters
  828. ----------
  829. t : array
  830. Angles in circle in radians. Angles start to count from positive
  831. x-axis to positive y-axis in a right-handed system.
  832. Returns
  833. -------
  834. xy : (..., 2) array
  835. Predicted x- and y-coordinates.
  836. Other parameters
  837. ----------------
  838. params : `~.DEPRECATED`, optional
  839. Optional ellipse model parameters in the following order ``xc``,
  840. ``yc``, `a`, `b`, `theta`.
  841. .. deprecated:: {{ start_version }}
  842. """
  843. t = np.asanyarray(t)
  844. (xc, yc), (a, b), theta = self._get_init_values(params)
  845. ct = np.cos(t)
  846. st = np.sin(t)
  847. ctheta = math.cos(theta)
  848. stheta = math.sin(theta)
  849. x = xc + a * ctheta * ct - b * stheta * st
  850. y = yc + a * stheta * ct + b * ctheta * st
  851. return np.concatenate((x[..., None], y[..., None]), axis=t.ndim)
  852. @_deprecate_estimate
  853. def estimate(self, data):
  854. """Estimate ellipse model from data using total least squares.
  855. Parameters
  856. ----------
  857. data : (N, 2) array
  858. N points with ``(x, y)`` coordinates, respectively.
  859. Returns
  860. -------
  861. success : bool
  862. True, if model estimation succeeds.
  863. References
  864. ----------
  865. .. [1] Halir, R.; Flusser, J. "Numerically stable direct least squares
  866. fitting of ellipses". In Proc. 6th International Conference in
  867. Central Europe on Computer Graphics and Visualization.
  868. WSCG (Vol. 98, pp. 125-132).
  869. """
  870. return self._estimate(data) is None
  871. def _dynamic_max_trials(n_inliers, n_samples, min_samples, probability):
  872. """Determine number trials such that at least one outlier-free subset is
  873. sampled for the given inlier/outlier ratio.
  874. Parameters
  875. ----------
  876. n_inliers : int
  877. Number of inliers in the data.
  878. n_samples : int
  879. Total number of samples in the data.
  880. min_samples : int
  881. Minimum number of samples chosen randomly from original data.
  882. probability : float
  883. Probability (confidence) that one outlier-free sample is generated.
  884. Returns
  885. -------
  886. trials : int
  887. Number of trials.
  888. """
  889. if probability == 0:
  890. return 0
  891. if n_inliers == 0:
  892. return np.inf
  893. inlier_ratio = n_inliers / n_samples
  894. nom = 1 - probability
  895. denom = 1 - inlier_ratio**min_samples
  896. # Keep (de-)nominator in the range of [_EPSILON, 1 - _EPSILON] so that
  897. # it is always guaranteed that the logarithm is negative and we return
  898. # a positive number of trials.
  899. nom = np.clip(nom, a_min=_EPSILON, a_max=1 - _EPSILON)
  900. denom = np.clip(denom, a_min=_EPSILON, a_max=1 - _EPSILON)
  901. return np.ceil(np.log(nom) / np.log(denom))
  902. def add_from_estimate(cls):
  903. """Add ``from_estimate`` method class using ``estimate`` method"""
  904. if hasattr(cls, 'from_estimate'):
  905. if not inspect.ismethod(cls.from_estimate):
  906. raise TypeError(f'Class {cls} `from_estimate` must be a ' 'class method.')
  907. return cls
  908. if not hasattr(cls, 'estimate'):
  909. raise TypeError(
  910. f'Class {cls} must have `from_estimate` class method '
  911. 'or `estimate` method.'
  912. )
  913. warn(
  914. "Passing custom classes without `from_estimate` has been deprecated "
  915. "since version 0.26 and will be removed in version 2.2. "
  916. "Add `from_estimate` class method to custom class to avoid this "
  917. "warning.",
  918. category=FutureWarning,
  919. stacklevel=3,
  920. )
  921. class FromEstimated(cls):
  922. @classmethod
  923. def from_estimate(klass, *args, **kwargs):
  924. # Assume we can make default instance without input arguments.
  925. instance = klass()
  926. success = instance.estimate(*args, **kwargs)
  927. return (
  928. instance
  929. if success
  930. else FailedEstimation(f'`{cls.__name__}` estimation failed')
  931. )
  932. return FromEstimated
  933. def ransac(
  934. data,
  935. model_class,
  936. min_samples,
  937. residual_threshold,
  938. is_data_valid=None,
  939. is_model_valid=None,
  940. max_trials=100,
  941. stop_sample_num=np.inf,
  942. stop_residuals_sum=0,
  943. stop_probability=1,
  944. rng=None,
  945. initial_inliers=None,
  946. ):
  947. """Fit a model to data with the RANSAC (random sample consensus) algorithm.
  948. RANSAC is an iterative algorithm for the robust estimation of parameters
  949. from a subset of inliers from the complete data set. Each iteration
  950. performs the following tasks:
  951. 1. Select `min_samples` random samples from the original data and check
  952. whether the set of data is valid (see `is_data_valid`).
  953. 2. Estimate a model to the random subset
  954. (`model_cls.from_estimate(*data[random_subset]`) and check whether the
  955. estimated model is valid (see `is_model_valid`).
  956. 3. Classify all data as inliers or outliers by calculating the residuals
  957. to the estimated model (`model_cls.residuals(*data)`) - all data samples
  958. with residuals smaller than the `residual_threshold` are considered as
  959. inliers.
  960. 4. Save estimated model as best model if number of inlier samples is
  961. maximal. In case the current estimated model has the same number of
  962. inliers, it is only considered as the best model if it has less sum of
  963. residuals.
  964. These steps are performed either a maximum number of times or until one of
  965. the special stop criteria are met. The final model is estimated using all
  966. inlier samples of the previously determined best model.
  967. Parameters
  968. ----------
  969. data : list or tuple or array of shape (N,)
  970. Data set to which the model is fitted, where N is the number of data
  971. points and the remaining dimension are depending on model requirements.
  972. If the model class requires multiple input data arrays (e.g. source and
  973. destination coordinates of ``skimage.transform.AffineTransform``),
  974. they can be optionally passed as tuple or list. Note, that in this case
  975. the functions ``estimate(*data)``, ``residuals(*data)``,
  976. ``is_model_valid(model, *random_data)`` and
  977. ``is_data_valid(*random_data)`` must all take each data array as
  978. separate arguments.
  979. model_class : type
  980. Class with the following methods:
  981. * Either:
  982. * ``from_estimate`` class method returning transform instance, as in
  983. ``tform = model_class.from_estimate(*data)``; the resulting
  984. ``tform`` should be truthy (``bool(tform) == True``) where
  985. estimation succeeded, or falsey (``bool(tform) == False``) where it
  986. failed; OR
  987. * (deprecated) ``estimate`` instance method, returning flag to
  988. indicate successful estimation, as in ``tform = model_class();
  989. success = tform.estimate(*data)``. ``success == True`` when
  990. estimation succeeded, ``success == False`` when it failed.
  991. * ``residuals(*data)``
  992. Your model should conform to the ``RansacModelProtocol`` — meaning
  993. implement all of the methods / attributes specified by the
  994. :class:``RansacModelProctocol``. An easy check to see whether that is
  995. the case is to use ``isinstance(MyModel, RansacModelProtocol)``. See
  996. https://docs.python.org/3/library/typing.html#typing.Protocol for more
  997. details.
  998. min_samples : int, in range (0, N)
  999. The minimum number of data points to fit a model to.
  1000. residual_threshold : float, >0
  1001. Maximum distance for a data point to be classified as an inlier.
  1002. is_data_valid : Callable, optional
  1003. This function is called with the randomly selected data before the
  1004. model is fitted to it: `is_data_valid(*random_data)`.
  1005. is_model_valid : Callable, optional
  1006. This function is called with the estimated model and the randomly
  1007. selected data: `is_model_valid(model, *random_data)`, .
  1008. max_trials : int, optional
  1009. Maximum number of iterations for random sample selection.
  1010. stop_sample_num : int, optional
  1011. Stop iteration if at least this number of inliers are found.
  1012. stop_residuals_sum : float, optional
  1013. Stop iteration if sum of residuals is less than or equal to this
  1014. threshold.
  1015. stop_probability : float, optional, in range [0, 1]
  1016. RANSAC iteration stops if at least one outlier-free set of the
  1017. training data is sampled with ``probability >= stop_probability``,
  1018. depending on the current best model's inlier ratio and the number
  1019. of trials. This requires to generate at least N samples (trials):
  1020. N >= log(1 - probability) / log(1 - e**m)
  1021. where the probability (confidence) is typically set to a high value
  1022. such as 0.99, e is the current fraction of inliers w.r.t. the
  1023. total number of samples, and m is the min_samples value.
  1024. rng : {`numpy.random.Generator`, int}, optional
  1025. Pseudo-random number generator.
  1026. By default, a PCG64 generator is used (see :func:`numpy.random.default_rng`).
  1027. If `rng` is an int, it is used to seed the generator.
  1028. initial_inliers : array-like of bool, shape (N,), optional
  1029. Initial samples selection for model estimation
  1030. Returns
  1031. -------
  1032. model : object
  1033. Best model with largest consensus set.
  1034. inliers : (N,) array
  1035. Boolean mask of inliers classified as ``True``.
  1036. References
  1037. ----------
  1038. .. [1] "RANSAC", Wikipedia, https://en.wikipedia.org/wiki/RANSAC
  1039. Examples
  1040. --------
  1041. Generate ellipse data without tilt and add noise:
  1042. >>> t = np.linspace(0, 2 * np.pi, 50)
  1043. >>> xc, yc = 20, 30
  1044. >>> a, b = 5, 10
  1045. >>> x = xc + a * np.cos(t)
  1046. >>> y = yc + b * np.sin(t)
  1047. >>> data = np.column_stack([x, y])
  1048. >>> rng = np.random.default_rng(203560) # do not copy this value
  1049. >>> data += rng.normal(size=data.shape)
  1050. Add some faulty data:
  1051. >>> data[0] = (100, 100)
  1052. >>> data[1] = (110, 120)
  1053. >>> data[2] = (120, 130)
  1054. >>> data[3] = (140, 130)
  1055. Estimate ellipse model using all available data:
  1056. >>> model = EllipseModel.from_estimate(data)
  1057. >>> np.round(model.center)
  1058. array([71., 75.])
  1059. >>> np.round(model.axis_lengths)
  1060. array([77., 13.])
  1061. >>> np.round(model.theta)
  1062. 1.0
  1063. Next we estimate an ellipse model using RANSAC.
  1064. Note that the results are not deterministic, because the RANSAC algorithm
  1065. uses some randomness. If you need the results to be deterministic, pass a
  1066. seeded number generator with the ``rng`` argument to ``ransac``.
  1067. >>> ransac_model, inliers = ransac(data, EllipseModel, 20, 3, max_trials=50)
  1068. >>> np.abs(np.round(ransac_model.center)) # doctest: +SKIP
  1069. array([20., 30.])
  1070. >>> np.abs(np.round(ransac_model.axis_lengths)) # doctest: +SKIP
  1071. array([10., 6.])
  1072. >>> np.abs(np.round(ransac_model.theta)) # doctest: +SKIP
  1073. 2.0
  1074. >>> inliers # doctest: +SKIP
  1075. array([False, False, False, False, True, True, True, True, True,
  1076. True, True, True, True, True, True, True, True, True,
  1077. True, True, True, True, True, True, True, True, True,
  1078. True, True, True, True, True, True, True, True, True,
  1079. True, True, True, True, True, True, True, True, True,
  1080. True, True, True, True, True], dtype=bool)
  1081. >>> sum(inliers) > 40
  1082. True
  1083. RANSAC can be used to robustly estimate a geometric
  1084. transformation. In this section, we also show how to use a
  1085. proportion of the total samples, rather than an absolute number.
  1086. >>> from skimage.transform import SimilarityTransform
  1087. >>> rng = np.random.default_rng()
  1088. >>> src = 100 * rng.random((50, 2))
  1089. >>> model0 = SimilarityTransform(scale=0.5, rotation=1,
  1090. ... translation=(10, 20))
  1091. >>> dst = model0(src)
  1092. >>> dst[0] = (10000, 10000)
  1093. >>> dst[1] = (-100, 100)
  1094. >>> dst[2] = (50, 50)
  1095. >>> ratio = 0.5 # use half of the samples
  1096. >>> min_samples = int(ratio * len(src))
  1097. >>> model, inliers = ransac(
  1098. ... (src, dst),
  1099. ... SimilarityTransform,
  1100. ... min_samples,
  1101. ... 10,
  1102. ... initial_inliers=np.ones(len(src), dtype=bool),
  1103. ... ) # doctest: +SKIP
  1104. >>> inliers # doctest: +SKIP
  1105. array([False, False, False, True, True, True, True, True, True,
  1106. True, True, True, True, True, True, True, True, True,
  1107. True, True, True, True, True, True, True, True, True,
  1108. True, True, True, True, True, True, True, True, True,
  1109. True, True, True, True, True, True, True, True, True,
  1110. True, True, True, True, True])
  1111. """
  1112. best_inlier_num = 0
  1113. best_inlier_residuals_sum = np.inf
  1114. best_inliers = []
  1115. validate_model = is_model_valid is not None
  1116. validate_data = is_data_valid is not None
  1117. rng = np.random.default_rng(rng)
  1118. # in case data is not pair of input and output, male it like it
  1119. if not isinstance(data, (tuple, list)):
  1120. data = (data,)
  1121. num_samples = len(data[0])
  1122. if not (0 < min_samples <= num_samples):
  1123. raise ValueError(f"`min_samples` must be in range (0, {num_samples}]")
  1124. if residual_threshold < 0:
  1125. raise ValueError("`residual_threshold` must be greater than zero")
  1126. if max_trials < 0:
  1127. raise ValueError("`max_trials` must be greater than zero")
  1128. if not (0 <= stop_probability <= 1):
  1129. raise ValueError("`stop_probability` must be in range [0, 1]")
  1130. if initial_inliers is not None and len(initial_inliers) != num_samples:
  1131. raise ValueError(
  1132. f"RANSAC received a vector of initial inliers (length "
  1133. f"{len(initial_inliers)}) that didn't match the number of "
  1134. f"samples ({num_samples}). The vector of initial inliers should "
  1135. f"have the same length as the number of samples and contain only "
  1136. f"True (this sample is an initial inlier) and False (this one "
  1137. f"isn't) values."
  1138. )
  1139. # for the first run use initial guess of inliers
  1140. spl_idxs = (
  1141. initial_inliers
  1142. if initial_inliers is not None
  1143. else rng.choice(num_samples, min_samples, replace=False)
  1144. )
  1145. # Ensure model_class has from_estimate class method.
  1146. model_class = add_from_estimate(model_class)
  1147. # Check protocol.
  1148. if not isinstance(model_class, RansacModelProtocol):
  1149. raise TypeError(
  1150. f"`model_class` {model_class} should be of (protocol) type "
  1151. "RansacModelProtocol"
  1152. )
  1153. num_trials = 0
  1154. # max_trials can be updated inside the loop, so this cannot be a for-loop
  1155. while num_trials < max_trials:
  1156. num_trials += 1
  1157. # do sample selection according data pairs
  1158. samples = [d[spl_idxs] for d in data]
  1159. # for next iteration choose random sample set and be sure that
  1160. # no samples repeat
  1161. spl_idxs = rng.choice(num_samples, min_samples, replace=False)
  1162. # optional check if random sample set is valid
  1163. if validate_data and not is_data_valid(*samples):
  1164. continue
  1165. model = model_class.from_estimate(*samples)
  1166. # backwards compatibility
  1167. if not model:
  1168. continue
  1169. # optional check if estimated model is valid
  1170. if validate_model and not is_model_valid(model, *samples):
  1171. continue
  1172. residuals = np.abs(model.residuals(*data))
  1173. # consensus set / inliers
  1174. inliers = residuals < residual_threshold
  1175. residuals_sum = residuals.dot(residuals)
  1176. # choose as new best model if number of inliers is maximal
  1177. inliers_count = np.count_nonzero(inliers)
  1178. if (
  1179. # more inliers
  1180. inliers_count > best_inlier_num
  1181. # same number of inliers but less "error" in terms of residuals
  1182. or (
  1183. inliers_count == best_inlier_num
  1184. and residuals_sum < best_inlier_residuals_sum
  1185. )
  1186. ):
  1187. best_inlier_num = inliers_count
  1188. best_inlier_residuals_sum = residuals_sum
  1189. best_inliers = inliers
  1190. max_trials = min(
  1191. max_trials,
  1192. _dynamic_max_trials(
  1193. best_inlier_num, num_samples, min_samples, stop_probability
  1194. ),
  1195. )
  1196. if (
  1197. best_inlier_num >= stop_sample_num
  1198. or best_inlier_residuals_sum <= stop_residuals_sum
  1199. ):
  1200. break
  1201. # estimate final model using all inliers
  1202. if any(best_inliers):
  1203. # select inliers for each data array
  1204. data_inliers = [d[best_inliers] for d in data]
  1205. model = model_class.from_estimate(*data_inliers)
  1206. if validate_model and not is_model_valid(model, *data_inliers):
  1207. warn("Estimated model is not valid. Try increasing max_trials.")
  1208. else:
  1209. model = None
  1210. best_inliers = None
  1211. warn("No inliers found. Model not fitted")
  1212. # Return model from wrapper, otherwise model itself.
  1213. return getattr(model, 'model', model), best_inliers