Source code for phitodeep.optimization.initialization

from enum import Enum
import numpy as np


class InitType(Enum):
    """Tag selecting which variant of an initialization scheme to use.

    The two members are named after the normal and uniform variants. The
    enum itself carries no behavior; code that receives an ``InitType``
    compares against these members to pick a formula.
    """

    NORMAL = 1
    UNIFORM = 2
class Initializer:
    """Base class for weight-initialization scale rules.

    This class only stores the configuration; it does not compute anything
    itself. Subclasses are expected to override :meth:`get_scale`.

    Args:
        name: Human-readable name of the scheme; it is interpolated into the
            error raised by the base :meth:`get_scale`.
        init_type: Which variant of the scheme to use
            (``InitType.NORMAL`` by default).
    """

    def __init__(
        self,
        name: str = "Base",
        init_type: InitType = InitType.NORMAL,
    ):
        self.name = name
        self.init_type = init_type

    def get_scale(self, weights: np.ndarray) -> float:
        """Return the initialization scale for ``weights``.

        The base implementation is a placeholder and never returns.

        Args:
            weights: Array the scale is computed for (unused here).

        Raises:
            NotImplementedError: Always; subclasses must override this
                method.
        """
        # The message must name this method (get_scale) so the hint points
        # subclass authors at the method they actually need to override.
        raise NotImplementedError(
            f"{self.name} initialization must implement get_scale."
        )
class Xavier(Initializer):
    """Xavier scale rule: depends on both fan-in and fan-out.

    Args:
        init_type: Variant to compute the scale for
            (``InitType.NORMAL`` by default).
    """

    def __init__(self, init_type: InitType = InitType.NORMAL):
        super().__init__(name="Xavier", init_type=init_type)

    def get_scale(self, weights: np.ndarray) -> float:
        """Return ``sqrt(k / (fan_in + fan_out))`` for ``weights``.

        ``k`` is 2 when ``init_type`` equals ``InitType.NORMAL`` and 6 for
        any other value. ``fan_in`` is ``weights.shape[0]`` and ``fan_out``
        is ``weights.shape[1]``, so ``weights`` must have at least two
        dimensions.

        Args:
            weights: Array whose first two shape entries give the fans.

        Returns:
            The scale as computed by ``np.sqrt``.
        """
        # Pick the numerator first, then read the shape.
        if self.init_type == InitType.NORMAL:
            numerator = 2
        else:
            numerator = 6

        shape = weights.shape
        fan_total = shape[0] + shape[1]
        return np.sqrt(numerator / fan_total)
class He(Initializer):
    """He scale rule: depends on fan-in only.

    Args:
        init_type: Variant to compute the scale for
            (``InitType.NORMAL`` by default).
    """

    def __init__(self, init_type: InitType = InitType.NORMAL):
        super().__init__(name="He", init_type=init_type)

    def get_scale(self, weights: np.ndarray) -> float:
        """Return ``sqrt(k / fan_in)`` for ``weights``.

        ``k`` is 2 when ``init_type`` equals ``InitType.NORMAL`` and 6 for
        any other value. ``fan_in`` is ``weights.shape[0]``.

        Args:
            weights: Array whose first shape entry gives the fan-in.

        Returns:
            The scale as computed by ``np.sqrt``.
        """
        # Pick the numerator first, then read the shape.
        if self.init_type == InitType.NORMAL:
            numerator = 2
        else:
            numerator = 6

        rows = weights.shape[0]
        return np.sqrt(numerator / rows)