Search
Duplicate
💻

enformer-pytorch 코드 리뷰

enformer-pytorch
lucidrains

모델 구조

프로젝트 Scaffold 만들기

첫 class 구현

AttentionPool 구현

class AttentionPool(nn.Module): def __init__(self, dim, pool_size = 2): super().__init__() # b=batch size, d=channel, n=L/2, p=2 self.pool_fn = Rearrange('b d (n p) -> b d n p', p = 2) self.to_attn_logits = nn.Parameter(torch.eye(dim)) def forward(self, x): attn_logits = einsum('b d n, d e -> b e n', x, self.to_attn_logits) x = self.pool_fn(x) attn = self.pool_fn(attn_logits).softmax(dim = -1) return (x * attn).sum(dim = -1)
Python
복사
einsum 작동 방식?

최종적으로는 AttentionPool이 이렇게 구현된다

class AttentionPool(nn.Module): def __init__(self, dim, pool_size = 2): super().__init__() self.pool_size = pool_size self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size) self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False) # in_channels=dim, out_channels=dim, kernel_size=1, stride=1, bias=False # dim -> dim으로 가는 변환. dim=C 라고 보면 된다. # Conv2d weight initialize가 identity matrix로 되던가? def forward(self, x): b, _, n = x.shape remainder = n % self.pool_size needs_padding = remainder > 0 if needs_padding: x = F.pad(x, (0, remainder), value = 0) mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device) mask = F.pad(mask, (0, remainder), value = True) x = self.pool_fn(x) logits = self.to_attn_logits(x) if needs_padding: mask_value = -torch.finfo(logits.dtype).max logits = logits.masked_fill(self.pool_fn(mask), mask_value) attn = logits.softmax(dim = -1) return (x * attn).sum(dim = -1)
Python
복사
Conv2d weight 초기화 확인해보기
Identity matrix가 아닌데?

Residual class 구현

class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, **kwargs): return self.fn(x, **kwargs) + x
Python
복사
이렇게 Residual operation을 따로 클래스로 구현해 두면 nn.Sequential 내에서 이렇게 편리하게 사용 가능하다.