Package linear_operators :: Package iterative :: Module criterions
[hide private]
[frames] | [no frames]

Source Code for Module linear_operators.iterative.criterions

  1  """ 
  2  Implement the criterion class. Available : 
  3   
  4  - Criterion 
  5  - QuadraticCriterion 
  6  - HuberCriterion 
  7  """ 
  8  from copy import copy 
  9  import numpy as np 
 10  from norms import * 
 11  from ..interface import concatenate 
 12   
class Criterion(object):
    """
    A class representing criterions such as:

    .. math:: || y - H x ||^2 + \\sum_i \\lambda_i || D_i x ||^2

    Parameters
    ----------
    model : LinearOperator
        Model (H above).
    data : ndarray
        Data array (y above); flattened on input.
    hypers : tuple or array
        Hyperparameter of each prior.
    priors : tuple of LinearOperators
        Prior models (D_i above).
    norms : tuple
        Can be Norm2, Huber, Normp instances.  Defaults to Norm2 for
        the data term and for every prior.
    store : boolean (default: True)
        Store the residuals of the last evaluation point so repeated
        __call__/gradient at the same x reuse them.

    Returns
    -------
    A Criterion instance with __call__ and gradient methods.
    """
    def __init__(self, model, data, hypers=(), priors=(), norms=(), store=True):
        self.model = model
        self.data = data.ravel()
        self.priors = priors
        # normalize hyperparameters by the data-size / unknown-size ratio
        self.hypers = np.asarray(hypers) * model.shape[0] / float(model.shape[1])
        # default to Norm2: one norm for the data term, one per prior
        self.norms = norms
        if len(self.norms) == 0:
            self.norms = (Norm2(),) * (len(self.priors) + 1)
        # first-order derivatives of the norms
        self.dnorms = [n.diff for n in self.norms]
        # storage of the last evaluation point and its residuals
        self.store = store
        self.last_x = None
        self.projection = None
        self.last_residual = None
        self.last_prior_residuals = None
        # If all norms are l2 an optimal step can be defined.
        # Bug fix: the original tested `n == Norm2`, comparing an
        # instance to the class, which is always False.
        self._optimal_step = np.all([isinstance(n, Norm2) for n in self.norms])
        # number of unknowns
        self.n_variables = self.model.shape[1]

    def islast(self, x):
        """Return True if x equals the last stored evaluation point."""
        return np.all(x == self.last_x)

    def load_last(self):
        """Return the stored (projection, data residual, prior residuals)."""
        return self.projection, self.last_residual, self.last_prior_residuals

    def get_residuals(self, x):
        """
        Compute (or reload when x is the last stored point) the
        projection Hx, the data residual Hx - y, and the prior
        residuals D_i x.
        """
        if self.islast(x):
            Hx, r, rd = self.load_last()
        else:
            Hx = self.model * x
            r = Hx - self.data
            rd = [D * x for D in self.priors]
        return Hx, r, rd

    def save(self, x, Hx, r, rd):
        """Store the evaluation point x and its residuals for reuse."""
        if self.store and not self.islast(x):
            self.last_x = copy(x)
            # Bug fix: the original assigned to self.Hx, but load_last
            # reads self.projection (initialized to None and never
            # updated), so the cache always returned None for Hx.
            self.projection = copy(Hx)
            self.last_residual = copy(r)
            self.last_prior_residuals = copy(rd)

    def __call__(self, x):
        """Evaluate the criterion at x."""
        # residuals
        Hx, r, rd = self.get_residuals(x)
        # data term plus hyperparameter-weighted prior terms
        J = self.norms[0](r)
        priors = [norm(rd_i) for norm, rd_i in zip(self.norms[1:], rd)]
        J += np.sum([h * prior for h, prior in zip(self.hypers, priors)])
        self.save(x, Hx, r, rd)
        return J

    def gradient(self, x):
        """
        First order derivative of the criterion as a function of x.
        """
        Hx, r, rd = self.get_residuals(x)
        # model.T applied to the derivative of the data-term norm
        g = self.model.T * self.dnorms[0](r)
        p_dnorms = [dnorm(el) for dnorm, el in zip(self.dnorms[1:], rd)]
        p_diff = [D.T * dn for D, dn in zip(self.priors, p_dnorms)]
        drs = [h * pd for h, pd in zip(self.hypers, p_diff)]
        for dr in drs:
            g += dr
        self.save(x, Hx, r, rd)
        return g
101
class QuadraticCriterion(Criterion):
    """
    Subclass of Criterion with all norms forced to be Norm2 instances.

    Parameters are those of Criterion (without `norms`), plus:

    hessian : boolean (default: False)
        If True, expose `hessian` and `hessian_p` methods returning the
        constant Hessian operator and Hessian-vector products.
    """
    def __init__(self, model, data, hypers=(), priors=(), store=True,
                 hessian=False):
        # one Norm2 for the data term and one per prior
        norms = (Norm2(),) * (len(priors) + 1)
        # Concatenated weighted prior operator.
        # NOTE(review): the weights enter squared in the Hessian below
        # (self.prior.T * self.prior yields h**2 * D.T * D), and they
        # are the raw hypers, not the normalized ones Criterion
        # computes — confirm whether sqrt(h) weighting was intended.
        self.prior = concatenate([h * p for h, p in zip(hypers, priors)])
        self._hessian_model = model.T * model + self.prior.T * self.prior
        if hessian:
            self.hessian = self._hessian
            self.hessian_p = self._hessian_p
        # Bug fix: the original built `norms` but never forwarded it,
        # silently relying on Criterion's empty-norms default.
        Criterion.__init__(self, model, data, hypers=hypers, priors=priors,
                           norms=norms, store=store)

    def _hessian(self, u):
        """Return the (constant) Hessian operator; `u` is ignored."""
        return self._hessian_model

    def _hessian_p(self, u, p):
        """Return the Hessian applied to the vector `p`; `u` is ignored."""
        return self._hessian_model * p
119
class HuberCriterion(Criterion):
    """
    Subclass of Criterion with all norms forced to be Huber instances.

    Parameters are those of Criterion (without `norms`), plus:

    deltas : tuple
        Huber thresholds, one per norm — presumably len(priors) + 1
        values (data term first, then one per prior); confirm against
        the Huber class.
    """
    def __init__(self, model, data, hypers=(), deltas=(), priors=(), store=True):
        # one Huber norm per threshold
        norms = [Huber(d) for d in deltas]
        # Bug fix: the original never forwarded `norms`, so every norm
        # silently defaulted to Norm2 and `deltas` was ignored,
        # contradicting this class's documented purpose.
        Criterion.__init__(self, model, data, hypers=hypers, priors=priors,
                           norms=norms, store=store)
127