Fine-tune Llama 2 on Replicate
Posted by @cbh123
Llama 2 is the first open-source language model of the same caliber as OpenAI’s models, and because it’s open source you can hack it to do new things that aren’t possible with GPT-4.
Like become a better poet. Talk like Homer Simpson. Write Midjourney prompts. Or replace your best friends.
One of the main reasons to fine-tune models is so you can use a small model to do a task that would normally require a large model. This means you can do the same task, but cheaper and faster. For example, the 7-billion-parameter Llama 2 model is not good at summarizing text, but we can teach it how.
In this guide, we’ll show you how to create a text summarizer. We’ll be using Llama 2 7B, an open-source large language model from Meta, and fine-tuning it on a dataset of messenger-like conversations with summaries. When we’re done, you’ll be able to distill chat transcripts, emails, webpages, and other documents into a brief summary. Short and sweet.
Supported models
Here are the Llama models on Replicate that you can fine-tune:
If your model will be responding to instructions from users, use the chat models. If you are just completing text, you’ll want to use the base models.
Training data
Your training data should be in a JSONL text file.
In this guide, we’ll be using the SAMSum dataset, transformed into JSONL.
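JSONL simply means one JSON object per line. As a rough sketch of how conversation/summary pairs could be turned into such a file, the snippet below assumes each record needs a `prompt` and a `completion` key; that schema is our assumption, so check the model’s “Train” tab before relying on it. It lays out each prompt like the inference prompt used later in this guide, since the hosted SAMSum file’s exact formatting may differ. The helper names and the example pair are hypothetical.

```python
import json
import os
import tempfile

# Assumption: the trainer reads one JSON object per line with "prompt" and
# "completion" keys. Verify the exact schema in the model's "Train" tab.
SYSTEM_PROMPT = "Use the Input to provide a summary of a conversation."


def format_prompt(dialogue: str) -> str:
    """Lay out a transcript like the inference prompt used later in this guide."""
    return (
        "[INST] <<SYS>>\n"
        f"{SYSTEM_PROMPT}\n"
        "<</SYS>>\n"
        "Input:\n"
        f"{dialogue.strip()} [/INST]\n"
        "Summary: "
    )


def write_jsonl(pairs, path):
    """Write (dialogue, summary) pairs to `path`, one JSON record per line."""
    with open(path, "w", encoding="utf-8") as f:
        for dialogue, summary in pairs:
            record = {"prompt": format_prompt(dialogue), "completion": summary}
            # json.dumps escapes embedded newlines, so each record stays on one line
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


# Hypothetical example pair, not taken from SAMSum
pairs = [
    ("Alice: Are we still on for lunch?\nBob: Yes, 12:30 at the usual place.",
     "Alice and Bob confirm lunch at 12:30."),
]
out_path = os.path.join(tempfile.mkdtemp(), "train.jsonl")
write_jsonl(pairs, out_path)
with open(out_path, encoding="utf-8") as f:
    print(f.read())
```

Once written, the file needs to be reachable by URL (like the gist used below) so it can be passed as `train_data`.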
Create a model
You need to create an empty model on Replicate for your trained model. When your training finishes, it will be pushed as a new version to this model.
Go to replicate.com/create and create a new model called “llama2-summarizer”.
Authenticate
Authenticate by setting your token in an environment variable:
export REPLICATE_API_TOKEN=<paste-your-token-here>
Find your API token in your account settings.
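The replicate Python client picks up `REPLICATE_API_TOKEN` from the environment on its own, so there is nothing else to configure. If you’d like a script to fail fast with a clear message when the variable is missing, a small check like this hypothetical `require_replicate_token` helper can run before any API call:

```python
import os
from typing import Mapping


def require_replicate_token(env: Mapping[str, str] = os.environ) -> str:
    """Return the Replicate API token, raising early if it has not been exported."""
    token = env.get("REPLICATE_API_TOKEN", "").strip()
    if not token:
        raise RuntimeError(
            "REPLICATE_API_TOKEN is not set; export it before calling the API."
        )
    return token
```

Passing `env` explicitly makes the check easy to exercise without touching the real environment.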
Create a training
Install the Python library:
pip install replicate
And kick off training, replacing the destination name with your username and the name of your new model:
import replicate

training = replicate.trainings.create(
    # The base model and version to fine-tune
    version="meta/llama-2-7b:73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9",
    input={
        # Publicly reachable URL of the JSONL training data
        "train_data": "https://gist.githubusercontent.com/nateraw/055c55b000e4c37d43ce8eb142ccc0a2/raw/d13853512fc83e8c656a3e8b6e1270dd3c398e77/samsum.jsonl",
        "num_train_epochs": 3,
    },
    # The empty model you created earlier; the trained weights are pushed here as a new version
    destination="<your-username>/llama2-summarizer",
)
print(training)
It takes these arguments:
version: The model to train, in the format {username}/{model}:{version}.
input: The training data and params to pass to the training process, which are defined by the model. Llama 2’s params can be found in the model’s “Train” tab.
destination: The model to push the trained version to, in the format your-username/your-model-name.
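A typo in either identifier usually only surfaces once the API rejects the request. As a sketch, you could check both strings locally first. The regular expressions below are our own loose approximation of the formats just described, an assumption rather than Replicate’s official naming rules, and the helper names are hypothetical.

```python
import re

# Loose approximations (assumed, not official) of the two identifier formats.
VERSION_RE = re.compile(
    r"^(?P<owner>[\w.-]+)/(?P<model>[\w.-]+):(?P<version>[0-9a-f]+)$"
)
DESTINATION_RE = re.compile(r"^(?P<owner>[\w.-]+)/(?P<model>[\w.-]+)$")


def parse_version(ref: str) -> dict:
    """Split "{username}/{model}:{version}" into its parts, or raise ValueError."""
    match = VERSION_RE.match(ref)
    if match is None:
        raise ValueError(f"expected {{username}}/{{model}}:{{version}}, got {ref!r}")
    return match.groupdict()


def check_destination(dest: str) -> str:
    """Reject destinations that aren't "owner/name", including unfilled placeholders."""
    if DESTINATION_RE.match(dest) is None:
        raise ValueError(f"expected your-username/your-model-name, got {dest!r}")
    return dest
```

Because `<` and `>` fall outside the allowed characters, forgetting to replace `<your-username>` is caught before the training is submitted.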
Once you’ve kicked off your training, visit replicate.com/trainings in your browser to monitor the progress.
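If you’d rather wait from a script than watch the browser, you can poll the training object. This sketch relies only on `training.reload()` and the `training.status` attribute of the replicate Python client, and assumes the statuses `starting`, `processing`, `succeeded`, `failed`, and `canceled`. `wait_for_training` itself is a hypothetical helper; `sleep` is a parameter so it can be tested without actually waiting.

```python
import time

# Statuses after which a training will not change any more (assumed).
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def wait_for_training(training, poll_seconds=30.0, sleep=time.sleep):
    """Poll until the training reaches a terminal status; raise unless it succeeded."""
    training.reload()
    while training.status not in TERMINAL_STATUSES:
        print(f"training status: {training.status}")
        sleep(poll_seconds)
        training.reload()
    if training.status != "succeeded":
        raise RuntimeError(f"training ended with status {training.status!r}")
    return training
```

After `training = wait_for_training(training)` returns, `training.output["version"]` should be populated.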
Run the model
You can now run your model from the web or with an API. To use your model in the browser, go to your model page.
To use your model with an API, run the version from the training output:
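# Continues from the training created above: refresh it to get its latest status and output
training.reload()
if training.status != "succeeded":
    # training.output (including the new version) is only available once training succeeds
    raise RuntimeError(f"Training is not finished yet (status: {training.status})")

prompt = """[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>
Input:
Harry: Who are you?
Hagrid: Rubeus Hagrid, Keeper of Keys and Grounds at Hogwarts. Of course, you know all about Hogwarts.
Harry: Sorry, no.
Hagrid: No? Blimey, Harry, did you never wonder where yer parents learned it all?
Harry: All what?
Hagrid: Yer a wizard, Harry.
Harry: I-- I'm a what?
Hagrid: A wizard! And a thumpin' good 'un, I'll wager, once you've been trained up a bit. [/INST]
Summary: """

output = replicate.run(
    training.output["version"],
    input={"prompt": prompt, "stop_sequences": "</s>"}
)
# The output arrives as a sequence of text chunks; print each one as it comes in
for s in output:
    print(s, end="", flush=True)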
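Printing chunks as they arrive suits interactive use. If you’d rather keep the summary as a single string (to store it, say), you could join the chunks instead. The hypothetical `collect_summary` below does that and, defensively, trims at the stop sequence in case it ever appears in the text; that trimming is our own precaution, not something the API is documented to require.

```python
from typing import Iterable


def collect_summary(chunks: Iterable[str], stop: str = "</s>") -> str:
    """Join streamed text chunks into one string, cutting at `stop` if present."""
    text = "".join(chunks)
    cut = text.find(stop)
    if cut != -1:
        text = text[:cut]
    return text.strip()


# Stand-in chunks for illustration; prints: A short summary.
print(collect_summary(["A short", " summary.", "</s>"]))
```

The streamed output can generally only be consumed once, so use `summary = collect_summary(output)` in place of the printing loop rather than after it.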
That’s it! You’ve fine-tuned Llama 2 and can run your new model with an API.
Next steps
Happy hacking! 🦙