I’m working on fine-tuning an LLM to build a fantasy football league model. The goal is for the model to output a team with (hopefully) high point potential for a given round of games. I have built a RAG dataset and implemented custom loss functions and metrics for fine-tuning.
Problem: Regardless of which model I try to fine-tune, I consistently hit a CUDA Out Of Memory (OOM) error. I started with Mistral-7B and went down to models with as few as 410M parameters (EleutherAI/pythia-410m, bigscience/bloomz-560m), but the OOM issue persists even with the smaller models.
Environment Details:

- EC2 Instance: g5.2xlarge
- GPU: NVIDIA A10G with 24 GB VRAM
- CPU RAM: 32 GB
What I Tried:

- Lowered the batch size to 1
- Added gradient accumulation
- Mixed-precision training
- QLoRA (even pythia-410m loaded in 4-bit and fine-tuned with the LoRA PEFT method crashed with OOM)
- Gradient checkpointing
- Removed the RAG pipeline entirely
- torch.cuda.empty_cache()
Despite these efforts, the OOM error still occurs. Given the hardware, I expected it to handle at least the smaller models without running into memory issues.
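For reference, this is roughly how the QLoRA attempt was wired up (a simplified sketch; the model name is one I actually tried, but the LoRA hyperparameters and the rest of the config here are approximate, not my exact code):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 quantization with bf16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-410m",
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)  # casts norms, enables input grads
model = get_peft_model(model, LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM",
))
model.gradient_checkpointing_enable()
```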
Notes:

- I set max_length=4096 because my input sequences are long (roughly 1,000-4,000 tokens).
- I'm using the Hugging Face transformers library.
I’m attaching my DataCollator and the training function:
```python
from typing import Any, Dict, List

import torch


class FantasyTeamDataCollator:
    def __init__(self, tokenizer, rag_retriever: SeasonSpecificRAG, max_length: int, eval_steps: int):
        self.tokenizer = tokenizer
        self.rag_retriever = rag_retriever
        self.max_length = max_length
        self.eval_steps = eval_steps
        self.steps = 0

    def __call__(self, batch):
        teams_batch = [sample['teams'] for sample in batch]
        dates_batch = [sample['date'] for sample in batch]
        seasons_batch = [sample['season'] for sample in batch]
        rag_info_batch = self.rag_retriever.retrieve_relevant_info(teams_batch, dates_batch, seasons_batch)

        processed_samples = [self.process_sample(sample, rag_info_batch[i]) for i, sample in enumerate(batch)]
        processed_samples = [result for result in processed_samples if result is not None]
        if not processed_samples:
            raise ValueError("All samples in the batch failed to process")

        return self.collate_batch(processed_samples)

    def process_sample(self, sample: Dict[str, Any], rag_info: Dict[str, List[str]]) -> Dict[str, Any]:
        combined_input = self.combine_input_with_rag(sample['text'], rag_info)
        # Every sample is padded to the full max_length (4096)
        input_encodings = self.tokenizer(combined_input, truncation=True,
                                         max_length=self.max_length, padding="max_length")
        return {
            "input_ids": torch.tensor(input_encodings["input_ids"]),
            "attention_mask": torch.tensor(input_encodings["attention_mask"]),
            "labels": torch.tensor(input_encodings["input_ids"]),
            "matches": sample['matches'],
            "round": sample['round']
        }

    def combine_input_with_rag(self, input_text: str, rag_info: Dict[str, List[str]]) -> str:
        combined_input = (f"{input_text}\n\n"
                          f"Relevant Information:\n"
                          f"Teams Info: {rag_info['teams']}\n"
                          f"Players Info: {rag_info['players']}")
        # Prepend the system prompts occasionally (instruction_prompt and
        # full_rules_prompt are defined elsewhere)
        if self.steps % self.eval_steps == 0:
            combined_input = (f"Instructions: {instruction_prompt}\n\n"
                              f"League Rules: {full_rules_prompt}\n\n"
                              f"{combined_input}")
        self.steps += 1
        return combined_input

    @staticmethod
    def collate_batch(batch):
        return {
            "input_ids": torch.stack([item["input_ids"] for item in batch]),
            "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
            "labels": torch.stack([item["labels"] for item in batch]),
            "matches": [item["matches"] for item in batch],
            "round": [item["round"] for item in batch]
        }
```
-----------------------------------------------------------------------------------------------
```python
from transformers import EarlyStoppingCallback, TrainingArguments


def fine_tune(self):
    train_dataset = self.fantasy_dataset.dataset_dict['train']
    eval_dataset = self.fantasy_dataset.dataset_dict['test']

    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=5,
        early_stopping_threshold=0.01,
    )

    training_args = TrainingArguments(
        output_dir=self.out_dir,
        num_train_epochs=self.num_epochs,
        per_device_train_batch_size=self.bz,
        per_device_eval_batch_size=self.bz,
        gradient_accumulation_steps=self.conf.train.accumulation_steps,
        load_best_model_at_end=True,
        metric_for_best_model='combined_score',
        greater_is_better=True,
        eval_strategy='epoch',
        eval_steps=self.eval_steps,  # ignored while eval_strategy='epoch'
        save_strategy='epoch',
        save_total_limit=10,
        fp16=False,
        bf16=True,
        remove_unused_columns=False,
        max_grad_norm=1.0,
        gradient_checkpointing=True
    )

    print('\nBegin fine-tuning the model')
    trainer = FantasyTrainer(
        model=self.model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=self.data_collator,
        compute_metrics=self.compute_metrics,
        callbacks=[early_stopping_callback],
        fantasy_team_loss=self.fantasy_team_loss,
        eval_steps=self.eval_steps,
        initial_structure_weight=self.structure_weight,
        min_structure_weight=self.min_structure_weight
    )
    trainer.train()
```
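In case the numbers help, this is how I've been checking GPU memory right before the crash, using the standard torch utilities:

```python
import torch

# Quick look at where GPU memory stands just before the OOM
print(f"allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")
print(f"peak:      {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
print(torch.cuda.memory_summary(abbreviated=True))
```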
------------------------------------------------------------------------------------------------
```python
from typing import Any, Dict, Union

import torch
from transformers import Trainer


class FantasyTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        # Extract custom arguments before handing the rest to Trainer
        self.fantasy_team_loss = kwargs.pop('fantasy_team_loss', None)
        self.eval_steps = kwargs.pop('eval_steps', 100)
        self.structure_weight = kwargs.pop('initial_structure_weight', 1.0)
        self.min_structure_weight = kwargs.pop('min_structure_weight', 0.1)
        super().__init__(*args, **kwargs)
        self.steps = 0
        self.losses = {'loss': [], 'lm_loss': [], 'structure_loss': []}

    def compute_loss(self, model, inputs, return_outputs=False):
        model_inputs = {k: v for k, v in inputs.items() if k in ['input_ids', 'attention_mask']}
        outputs = model(**model_inputs)

        # Custom loss: language-modeling term plus a team-structure term
        lm_loss, structure_loss = self.fantasy_team_loss(outputs.logits, inputs['input_ids'])
        total_loss = lm_loss + (self.structure_weight * structure_loss)

        # L2 regularization over all model parameters
        l2_lambda = 0.01  # adjust as needed
        l2_reg = torch.sum(torch.stack([p.pow(2.0).sum() for p in model.parameters()]))
        total_loss += l2_lambda * l2_reg

        # Track losses
        self.losses['loss'].append(total_loss.item())
        self.losses['lm_loss'].append(lm_loss.item())
        self.losses['structure_loss'].append(structure_loss.item())

        # Log metrics every eval_steps
        if self.steps % self.eval_steps == 0:
            self._log_metrics()

        # Decrease the structure weight over time, floored at the minimum
        self.structure_weight = max(self.min_structure_weight, self.structure_weight * 0.9)
        self.steps += 1

        return (total_loss, outputs) if return_outputs else total_loss

    def _move_model_to_device(self, model, device):
        # No-op override: skip Trainer's automatic device placement
        pass

    def train(self, resume_from_checkpoint: Union[str, bool] = None,
              trial: Union["optuna.Trial", Dict[str, Any]] = None, **kwargs):
        # Reset the step counter and loss history before each training run
        self.steps = 0
        self.losses = {k: [] for k in self.losses}
        return super().train(resume_from_checkpoint, trial, **kwargs)
```
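One suspicion about compute_loss above: the manual L2 term iterates over every model parameter, which pulls all of the (otherwise frozen) base weights into the autograd graph. A leaner variant I'm considering (hypothetical) would regularize only the parameters that require gradients, or drop the term entirely in favor of the built-in weight_decay in TrainingArguments:

```python
# Hypothetical replacement for the L2 block in compute_loss: only regularize
# parameters that actually train (e.g. the LoRA adapters), so the frozen base
# weights stay out of the autograd graph.
l2_reg = sum(p.pow(2).sum() for p in model.parameters() if p.requires_grad)
total_loss = total_loss + l2_lambda * l2_reg
```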
Questions:

- Is the hardware I'm using insufficient for this fine-tuning, particularly with sequence lengths up to 4096 tokens?
- Are there additional optimizations or techniques I should consider to mitigate the OOM errors?
Any insights, suggestions, or advice would be greatly appreciated.
Thanks in advance!