@@ -113,12 +113,12 @@ def __init__(self, config, layer_id):
 
         self.layer_id = layer_id
         attention_window = config.attention_window[self.layer_id]
-        assert attention_window % 2 == 0, (
-            f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
-        )
-        assert attention_window > 0, (
-            f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
-        )
+        assert (
+            attention_window % 2 == 0
+        ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
+        assert (
+            attention_window > 0
+        ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
 
         self.one_sided_attn_window_size = attention_window // 2
 
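For context, `attention_window` has to be even because it is split into two equal one-sided halves (`attention_window // 2` positions on each side of a token). A minimal standalone sketch of the same checks, with a hypothetical config value:

```python
# Minimal sketch (not library code): validate a per-layer attention window and
# derive the one-sided window size, mirroring the asserts in the hunk above.
attention_window = 512  # hypothetical value taken from config.attention_window[layer_id]

assert attention_window % 2 == 0, f"`attention_window` has to be an even value. Given {attention_window}"
assert attention_window > 0, f"`attention_window` has to be positive. Given {attention_window}"

one_sided_attn_window_size = attention_window // 2  # 256 positions to the left and 256 to the right
```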
@@ -152,9 +152,9 @@ def forward(
         value_vectors = self.value(hidden_states)
 
         seq_len, batch_size, embed_dim = hidden_states.size()
-        assert embed_dim == self.embed_dim, (
-            f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
-        )
+        assert (
+            embed_dim == self.embed_dim
+        ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
 
         # normalize query
         query_vectors /= math.sqrt(self.head_dim)
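The `query_vectors /= math.sqrt(self.head_dim)` context line is the usual scaled dot-product normalization applied before the sliding-window matmul. A small sketch with hypothetical shapes:

```python
import math
import torch

head_dim = 64                            # hypothetical per-head dimension
query_vectors = torch.randn(8, 2, 768)   # (seq_len, batch_size, embed_dim), as in the forward above
query_vectors = query_vectors / math.sqrt(head_dim)  # keeps q·k scores at a reasonable scale before softmax
```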
@@ -222,9 +222,9 @@ def forward(
         )  # use fp32 for numerical stability
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (self.num_heads,), (
-                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
-            )
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
             attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
 
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
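In this hunk `layer_head_mask` has shape `(num_heads,)` and is broadcast over the local attention probabilities, which are laid out here as `(batch_size, seq_len, num_heads, window)`. A rough sketch of that broadcast with made-up sizes:

```python
import torch

batch_size, seq_len, num_heads, window = 2, 8, 12, 5   # hypothetical sizes
attn_probs = torch.rand(batch_size, seq_len, num_heads, window)

layer_head_mask = torch.ones(num_heads)
layer_head_mask[3] = 0.0                                # e.g. prune head 3

# view(1, 1, -1, 1) aligns the mask with the num_heads dimension and zeroes that head everywhere
attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
```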
@@ -416,9 +416,9 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor
         overlap of size window_overlap
         """
         batch_size, seq_len, num_heads, head_dim = query.size()
-        assert seq_len % (window_overlap * 2) == 0, (
-            f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
-        )
+        assert (
+            seq_len % (window_overlap * 2) == 0
+        ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
         assert query.size() == key.size()
 
         chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
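The first assert in this hunk expects the sequence length to already be a multiple of `2 * window_overlap`; Longformer arranges this by padding inputs up to a multiple of the attention window before attention is computed. A hedged sketch of such a padding step (the helper name and shapes are illustrative, not the library's API):

```python
import torch
import torch.nn.functional as F

def pad_to_window_multiple(hidden_states: torch.Tensor, attention_window: int) -> torch.Tensor:
    # hidden_states: (batch_size, seq_len, embed_dim); illustrative helper only
    seq_len = hidden_states.size(1)
    padding_len = (attention_window - seq_len % attention_window) % attention_window
    # pad the sequence dimension on the right so seq_len becomes a multiple of attention_window
    return F.pad(hidden_states, (0, 0, 0, padding_len))

x = torch.randn(2, 1000, 768)                        # hypothetical batch
x = pad_to_window_multiple(x, attention_window=512)
assert x.size(1) % 512 == 0                          # now satisfies seq_len % (window_overlap * 2) == 0
```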
@@ -689,9 +689,9 @@ def _compute_global_attn_output_from_hidden(
 
         # apply layer head masking
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (self.num_heads,), (
-                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
-            )
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
             global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
                 batch_size, self.num_heads, max_num_global_attn_indices, seq_len
             )