1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
# Evaluation-related options are only switched on when a validation split
# exists; compute that condition once and reuse it below.
has_val = bool(val_set_size > 0)

training_args = transformers.TrainingArguments(
    # Batch sizing / schedule.
    per_device_train_batch_size=micro_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    warmup_steps=100,
    num_train_epochs=num_epochs,
    learning_rate=learning_rate,
    fp16=True,
    logging_steps=1,
    optim="adamw_torch",
    # Evaluate every `eval_step` steps when there is validation data,
    # otherwise disable evaluation entirely.
    evaluation_strategy="steps" if has_val else "no",
    eval_steps=eval_step if has_val else None,
    # Checkpoint every `save_step` steps, keeping at most 20 checkpoints.
    save_strategy="steps",
    save_steps=save_step,
    output_dir=output_dir,
    save_total_limit=20,
    # Reloading the best checkpoint needs evaluation metrics, so it is tied
    # to the same validation-split condition.
    load_best_model_at_end=has_val,
    ddp_find_unused_parameters=False if ddp else None,
    group_by_length=group_by_length,
    # Log to Weights & Biases when requested, else fall back to TensorBoard.
    report_to="wandb" if use_wandb else "tensorboard",
    run_name=wandb_run_name if use_wandb else None,
)

# Dynamically pads each batch; padded length is rounded up to a multiple of 8.
collator = transformers.DataCollatorForSeq2Seq(
    tokenizer,
    pad_to_multiple_of=8,
    return_tensors="pt",
    padding=True,
)

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=training_args,
    data_collator=collator,
)

# NOTE(review): presumably disabled because the generation KV cache is not
# useful during training -- confirm against the model setup earlier in the file.
model.config.use_cache = False
# Keep a handle on the original bound `state_dict` method before replacing it.
old_state_dict = model.state_dict


def _peft_state_dict(self, *_, **__):
    """Return the PEFT view of the model's full state dict.

    Any positional or keyword arguments passed by the caller are accepted
    and ignored; the original (unpatched) ``state_dict`` is always called
    with no arguments.
    """
    return get_peft_model_state_dict(self, old_state_dict())


# Bind the replacement to this specific model instance so that
# `model.state_dict(...)` goes through `get_peft_model_state_dict`.
# NOTE(review): depending on the installed PEFT version, `save_pretrained`
# may itself call `get_peft_model_state_dict`; verify that the saved adapter
# weights are non-empty with this override in place.
model.state_dict = _peft_state_dict.__get__(model, type(model))
# Run training; `resume_from_checkpoint` is forwarded unchanged to the
# Trainer.
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

# Write the trained model to the same directory used for checkpoints.
model.save_pretrained(output_dir)
|