Source code for banhxeo.model.classic.gru

from typing import Optional

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F

from jaxtyping import Integer
from pydantic import model_validator
from typing_extensions import Self

from banhxeo.core.vocabulary import Vocabulary
from banhxeo.model.classic.rnn import RNNConfig
from banhxeo.model.neural import NeuralLanguageModel, NeuralModelConfig


class GRUConfig(RNNConfig): ...
class GRUCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, bias: bool):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Each gate consumes the concatenation [h_prev, x_t], so its input
        # feature size is input_size + hidden_size.
        self.update_gate = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.reset_gate = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.candidate_gate = nn.Linear(
            input_size + hidden_size, hidden_size, bias=bias
        )
    def forward(
        self, input_t: torch.Tensor, h_prev: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size = input_t.shape[0]
        if h_prev is None:
            h_prev = torch.zeros(
                batch_size, self.hidden_size, device=input_t.device, dtype=input_t.dtype
            )

        concat_input = torch.cat([h_prev, input_t], dim=1)

        update_gate_val = F.sigmoid(self.update_gate(concat_input))
        reset_gate_val = F.sigmoid(self.reset_gate(concat_input))

        # The candidate state sees the reset-gated previous hidden state
        # concatenated with the current input along the feature dimension.
        candidate_input = torch.cat(
            [torch.mul(reset_gate_val, h_prev), input_t], dim=1
        )
        candidate_h_next = F.tanh(self.candidate_gate(candidate_input))

        h_next = torch.mul(update_gate_val, candidate_h_next) + torch.mul(
            (torch.ones_like(update_gate_val) - update_gate_val), h_prev
        )
        return h_next
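# For reference, GRUCell implements the standard GRU update equations:
#     z_t  = sigmoid(W_z [h_{t-1}, x_t] + b_z)      # update gate
#     r_t  = sigmoid(W_r [h_{t-1}, x_t] + b_r)      # reset gate
#     h~_t = tanh(W_h [r_t * h_{t-1}, x_t] + b_h)   # candidate state
#     h_t  = z_t * h~_t + (1 - z_t) * h_{t-1}
# where [., .] is concatenation along the feature dimension and * is elementwise.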
class GRU(NeuralLanguageModel):
    """A single-layer GRU language model. Uses batch_first layout by default."""
    def __init__(
        self,
        vocab: Vocabulary,
        embedding_dim: int,
        hidden_size: int,
        bias: bool = False,
    ):
        super().__init__(
            model_config=GRUConfig(
                embedding_dim=embedding_dim, hidden_size=hidden_size, bias=bias
            ),
            vocab=vocab,
        )
        self.config: GRUConfig

        self.input_size = self.config.embedding_dim
        self.hidden_size = hidden_size

        self.embedding_layer = nn.Embedding(
            num_embeddings=self.vocab.vocab_size,
            embedding_dim=self.config.embedding_dim,
            padding_idx=self.vocab.pad_id,
        )

        if self.config.num_layers > 1:
            raise ValueError("Stacked GRU isn't implemented yet")
            # self.rnn_cells = nn.ModuleDict()
            # for layer in range(self.config.num_layers):
            #     self.rnn_cells[f"layer_{layer}"] = GRUCell(
            #         self.input_size, self.hidden_size, self.config.bias
            #     )
        else:
            self.gru_cells = GRUCell(
                self.input_size, self.hidden_size, self.config.bias
            )
    def forward(
        self,
        input_ids: Integer[torch.Tensor, "batch seq"],  # noqa: F722
        attention_mask: Optional[Integer[torch.Tensor, "batch seq"]] = None,  # noqa: F722
        **kwargs,
    ):
        outputs_list = []

        # Get original sequence lengths (pack_padded_sequence needs a CPU tensor)
        original_seqs_len = einops.reduce(
            attention_mask, "batch seq -> batch", "sum"
        ).to("cpu")  # type: ignore

        # attention_mask: [batch, seq]
        # input_ids:      [batch, seq]
        # embeddings:     [batch, seq, embed_dim]
        embeddings = self.embedding_layer(input_ids)

        # Pack the padded batch so we only step over real tokens
        packed_inputs = nn.utils.rnn.pack_padded_sequence(
            input=embeddings,
            lengths=original_seqs_len,  # type: ignore
            batch_first=True,
            enforce_sorted=False,
        )

        # Then unpack the PackedSequence fields
        inputs, batch_sizes, sorted_indices, unsorted_indices = packed_inputs

        # Batch size at the first (longest) time step
        effective_batch_size = batch_sizes[0].item()

        # Create the initial hidden state
        h_prev = torch.zeros(
            effective_batch_size,  # type: ignore
            self.hidden_size,
            device=inputs.device,
            dtype=inputs.dtype,
        )

        max_seq_len = batch_sizes.size(0)

        # For a packed sequence, `inputs` is already (total_tokens_across_all_seqs,
        # input_size); slice it per time step based on batch_sizes.
        last_processed_idx = 0
        for t in range(max_seq_len):
            # Number of sequences still active at this time step
            current_batch_size = batch_sizes[t].item()

            # input_t has shape (current_batch_size, input_size)
            input_t = inputs[
                last_processed_idx : last_processed_idx + current_batch_size
            ]
            last_processed_idx += current_batch_size

            # Slice h_prev to current_batch_size: as sequences end,
            # the effective batch size shrinks.
            h_prev_t = h_prev[:current_batch_size]

            # Pass through our GRU cell
            h_next_t = self.gru_cells(input_t, h_prev_t)
            outputs_list.append(h_next_t)

            if current_batch_size < h_prev.shape[0]:
                # The batch shrank: keep h_prev sized for the max batch size,
                # carrying over the hidden states of sequences that already ended
                # and updating only the still-active ones.
                h_prev = torch.cat((h_next_t, h_prev[current_batch_size:]), dim=0)
            else:
                h_prev = h_next_t

        outputs_packed_data = torch.cat(outputs_list, dim=0)

        # Re-pack the outputs
        outputs = nn.utils.rnn.PackedSequence(
            outputs_packed_data, batch_sizes, sorted_indices, unsorted_indices
        )

        # h_prev holds the final hidden state of every sequence in sorted order;
        # unsort it back to the original batch order.
        final_hidden_states = h_prev.clone()
        h_n = einops.rearrange(
            final_hidden_states[unsorted_indices],
            "batch hidden -> 1 batch hidden",
            # PyTorch's nn.GRU h_n output is (num_layers * num_directions, batch_size, hidden_size)
        )

        return {"hidden_states": outputs, "last_hidden_state": h_n}
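# Illustrative sketch (not part of the module): GRU.forward above reproduces the
# packed-sequence bookkeeping that PyTorch's built-in nn.GRU performs when fed a
# PackedSequence. The shapes and values below are assumptions for demonstration only.
if __name__ == "__main__":
    _emb = torch.randn(2, 5, 8)              # [batch, seq, embed_dim]
    _lengths = torch.tensor([5, 3])          # true (unpadded) sequence lengths
    _packed = nn.utils.rnn.pack_padded_sequence(
        _emb, _lengths, batch_first=True, enforce_sorted=False
    )
    _gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
    _packed_out, _h_n = _gru(_packed)        # _h_n: (num_layers * num_directions, batch, hidden)
    _out, _ = nn.utils.rnn.pad_packed_sequence(_packed_out, batch_first=True)
    assert _out.shape == (2, 5, 16) and _h_n.shape == (1, 2, 16)
    # The GRU class above returns the analogous quantities as a dict, e.g. (assuming
    # a Vocabulary instance `vocab` built elsewhere):
    #     out = GRU(vocab, embedding_dim=8, hidden_size=16)(input_ids, attention_mask)
    #     out["hidden_states"]      -> PackedSequence of per-step hidden states
    #     out["last_hidden_state"]  -> (1, batch, hidden_size), like h_n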