Usage in 15 minutes#
Training One Model for All Tasks#
1. Define the tasks#
OFASys can co-train multiple multi-modal tasks flexibly.
>>> from ofasys import Task, Trainer, GeneralistModel
>>> task1 = Task(
... name='caption',
... instruction='[IMAGE:image_url] what does the image describe? -> [TEXT:caption]',
... micro_batch_size=4,
... )
>>> task2 = Task(
... name='text_infilling',
... instruction='what is the complete text of " [TEXT:sentence,mask_ratio=0.3] "? -> [TEXT:sentence]',
... micro_batch_size=2,
... )
In the simplest scenario, you only need to specify an instruction to define your task and a task name as an identifier. For more details about how to define a task for training, see Define a Task and Train a Task.
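To make the instruction format concrete, the sketch below splits an instruction into its source and target sides and lists the slots found on each. It is an illustrative regex-based parser written for this guide, not the parser OFASys uses internally. The names `Slot`, `parse_slots` and `parse_instruction` are hypothetical, and the sketch assumes that slots look like `[MODALITY:name,attr=value,...]` and that ` -> ` separates the input from the target.

```python
import re
from dataclasses import dataclass, field
from typing import Optional

# Assumed slot shape: [MODALITY], [MODALITY:name] or [MODALITY:name,attr=value,flag]
SLOT_RE = re.compile(r"\[([A-Z]+)(?::([^\],]+))?((?:,[^\],]+)*)\]")


@dataclass
class Slot:
    """Hypothetical container for one placeholder in an instruction."""
    modality: str
    name: Optional[str]
    attrs: dict = field(default_factory=dict)


def parse_slots(text):
    """Return every slot found in `text`, in order of appearance."""
    slots = []
    for modality, name, raw_attrs in SLOT_RE.findall(text):
        attrs = {}
        for item in filter(None, raw_attrs.split(",")):
            key, sep, value = item.partition("=")
            # Bare attributes such as `uncased` are stored as boolean flags.
            attrs[key] = value if sep else True
        slots.append(Slot(modality, name or None, attrs))
    return slots


def parse_instruction(instruction):
    """Split an instruction at ' -> ' and parse the slots on both sides."""
    source, sep, target = instruction.partition(" -> ")
    if not sep:
        raise ValueError("instruction must contain ' -> '")
    return parse_slots(source), parse_slots(target)


source_slots, target_slots = parse_instruction(
    '[IMAGE:image_url] what does the image describe? -> [TEXT:caption]'
)
print(source_slots)
print(target_slots)
```

Reading an instruction this way shows which data fields a task consumes (the source slots) and which it is trained to produce (the target slots).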
2. Set the Dataset#
A Task can use a regular PyTorch DataLoader, which can be built from either a Hugging Face Dataset or a custom PyTorch Dataset.
>>> from datasets import load_dataset
>>> task1.add_dataset(load_dataset('TheFusion21/PokemonCards')['train'], 'train')
>>> task2.add_dataset(load_dataset('glue', 'cola')['train'], 'train')
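For data that does not come from Hugging Face, a custom dataset only needs to return one record per index. The sketch below is a minimal map-style dataset over in-memory records. It rests on two assumptions that this page does not confirm: that the slot names in an instruction (such as `sentence`) are looked up as keys of each record, and that `add_dataset` accepts any map-style dataset. The class name `ListDataset` is hypothetical, and it is written without a `torch` dependency so it runs anywhere; in real use you would probably subclass `torch.utils.data.Dataset`.

```python
class ListDataset:
    """Minimal map-style dataset: implements __len__ and __getitem__.

    Hypothetical helper for illustration; not part of OFASys.
    """

    def __init__(self, records, required_keys):
        self.records = list(records)
        # Fail early if a record lacks a field that the instruction refers to.
        for index, record in enumerate(self.records):
            missing = [key for key in required_keys if key not in record]
            if missing:
                raise KeyError(f"record {index} is missing keys: {missing}")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        return self.records[index]


dataset = ListDataset(
    [{"sentence": "The cat sat on the mat."},
     {"sentence": "Colorless green ideas sleep furiously."}],
    required_keys=["sentence"],
)
print(len(dataset), dataset[0])
```

If those assumptions hold, such a dataset could be registered in the same way as the Hugging Face ones, for example with `task2.add_dataset(dataset, 'train')`.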
3. Create a Generalist Model and Train all Tasks Together#
The GeneralistModel of OFASys (OFA+) can handle multiple modalities, including TEXT, IMAGE, AUDIO, VIDEO, MOTION, BOX, and PHONE.
The OFASys Trainer “mixes” multiple Tasks with any dataset and abstracts away all the engineering complexity needed for scale.
>>> model = GeneralistModel()
>>> trainer = Trainer()
>>> trainer.fit(model=model, tasks=[task1, task2])
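As a mental model for what "mixing" tasks can mean, the sketch below draws one micro-batch from each task at every training step, restarting a task's data when it runs out. This is only one plausible scheduling scheme, written for illustration; this page does not describe the Trainer's actual strategy, and the function names `micro_batches` and `mixed_steps` are hypothetical.

```python
from itertools import cycle


def micro_batches(records, micro_batch_size):
    """Split `records` into consecutive chunks of at most `micro_batch_size`."""
    for start in range(0, len(records), micro_batch_size):
        yield records[start:start + micro_batch_size]


def mixed_steps(tasks, num_steps):
    """Yield `num_steps` dicts mapping task name -> micro-batch.

    `tasks` maps a task name to a (records, micro_batch_size) pair.
    Each task's micro-batches are cycled, so a small dataset repeats.
    """
    streams = {}
    for name, (records, size) in tasks.items():
        if not records:
            raise ValueError(f"task {name!r} has no records")
        streams[name] = cycle(list(micro_batches(records, size)))
    for _ in range(num_steps):
        yield {name: next(stream) for name, stream in streams.items()}


tasks = {
    "caption": (list(range(8)), 4),         # micro_batch_size=4, as for task1
    "text_infilling": (list(range(3)), 2),  # micro_batch_size=2, as for task2
}
for step in mixed_steps(tasks, num_steps=3):
    print(step)
```

Under this scheme the per-task `micro_batch_size` controls how much each task contributes to a step, which is one way to balance tasks whose datasets differ in size.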
The complete script is available at scripts/trainer_api.py. For details on defining tasks in YAML files and on distributed training, see Train a Task.
Inference with All Kinds of Tasks with One Checkpoint#
OFASys can run inference on multiple multi-modal tasks using a single checkpoint.
Load a multi-task checkpoint:
>>> from ofasys import OFASys
>>> model = OFASys.from_pretrained('http://ofasys.oss-cn-zhangjiakou.aliyuncs.com/model_hub/multitask_10k.pt')
>>> model = model.cuda() # Omit this line if you don't have a GPU
OFASys enables multi-task multi-modal inference through the instruction alone. Let’s go through a couple of examples!
Image Captioning#
>>> instruction = '[IMAGE:img] what does the image describe? -> [TEXT:cap]'
>>> data = {'img': "https://ofasys.oss-cn-zhangjiakou.aliyuncs.com/data/coco/2014/val2014/COCO_val2014_000000222628.jpg"}
>>> output = model.inference(instruction, data=data)
>>> print(output.text)
"a man and woman sitting in front of a laptop computer"
Visual Grounding#
>>> instruction = '[IMAGE:img] which region does the text " [TEXT:cap] " describe? -> [BOX:patch_boxes]'
>>> data = [
... {'img': "https://www.2008php.com/2014_Website_appreciate/2015-06-22/20150622131649.jpg", 'cap': 'hand'},
... {'img': "http://ofasys.oss-cn-zhangjiakou.aliyuncs.com/data/coco/2014/train2014/COCO_train2014_000000581563.jpg", 'cap': 'taxi'},
... ]
>>> output = model.inference(instruction, data=data)
>>> for i, out in enumerate(output):
... out.save_box(f'{i}.jpg')
Text Summarization#
>>> instruction = 'what is the summary of article " [TEXT:src] "? -> [TEXT:tgt]'
>>> data = {'src': "poland 's main opposition party tuesday endorsed president lech walesa in an upcoming "
... "presidential run-off election after a reformed communist won the first round of voting ."}
>>> output = model.inference(instruction, data=data)
>>> print(output.text)
"polish opposition endorses walesa in presidential run-off"
Table-to-Text Generation#
| Subject | Predicate | Object |
|---|---|---|
| Atlanta | OFFICIAL_POPULATION | 5,457,831 |
| [TABLECONTEXT] | METROPOLITAN_AREA | Atlanta |
| 5,457,831 | YEAR | 2012 |
| [TABLECONTEXT] | [TITLE] | List of metropolitan areas by population |
| Atlanta | COUNTRY | United States |
>>> instruction = 'structured knowledge: " [STRUCT:database,uncased] " . how to describe the tripleset ? -> [TEXT:tgt]'
>>> data = {
... 'database': [['Atlanta', 'OFFICIAL_POPULATION', '5,457,831'],
... ['[TABLECONTEXT]', 'METROPOLITAN_AREA', 'Atlanta'],
... ['5,457,831', 'YEAR', '2012'],
... ['[TABLECONTEXT]', '[TITLE]', 'List of metropolitan areas by population'],
... ['Atlanta', 'COUNTRY', 'United States'],
... ]
... }
>>> output = model.inference(instruction, data=data, beam_size=1)
>>> print(output.text)
"atlanta, united states has a population of 5,457,831 in 2012."
Text-to-SQL Generation#
Database: concert_singer

| Table | Fields |
|---|---|
| stadium | stadium_id, location, name, capacity, highest, lowest, average |
| singer | singer_id, name, country, song_name, song_release_year, age, is_male |
| concert | concert_id, concert_name, theme, stadium_id, year |
| singer_in_concert | concert_id, singer_id |
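In the example that follows, this schema is passed to the model as a nested list: the first row holds the database name, and each later row holds a table name plus its fields joined by ` , `. The helper below is a small sketch that builds that structure from a plain dictionary. The function name `schema_to_struct` is hypothetical and not part of OFASys; the output layout is inferred solely from the example on this page.

```python
def schema_to_struct(db_name, tables):
    """Build the nested-list schema layout used in the Text-to-SQL example.

    `tables` maps each table name to a list of its field names.
    Hypothetical helper; the layout is inferred from the example on this page.
    """
    rows = [[db_name]]
    for table_name, fields in tables.items():
        # Fields are joined with ' , ' (spaces on both sides of the comma).
        rows.append([table_name, " , ".join(fields)])
    return rows


database = schema_to_struct(
    "concert_singer",
    {
        "stadium": ["stadium_id", "location", "name", "capacity",
                    "highest", "lowest", "average"],
        "singer": ["singer_id", "name", "country", "song_name",
                   "song_release_year", "age", "is_male"],
        "concert": ["concert_id", "concert_name", "theme", "stadium_id", "year"],
        "singer_in_concert": ["concert_id", "singer_id"],
    },
)
for row in database:
    print(row)
```

Building the list programmatically keeps the separator consistent when the schema comes from database metadata rather than being typed by hand.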
>>> instruction = '" [TEXT:src] " ; structured knowledge: " [STRUCT:database,max_length=876] " . generating sql code. -> [TEXT:tgt]'
>>> database = [
... ['concert_singer'],
... ['stadium', 'stadium_id , location , name , capacity , highest , lowest , average'],
... ['singer', 'singer_id , name , country , song_name , song_release_year , age , is_male'],
... ['concert', 'concert_id , concert_name , theme , stadium_id , year'],
... ['singer_in_concert', 'concert_id , singer_id']
... ]
>>> data = [
... {'src': 'What are the names, countries, and ages for every singer in descending order of age?', 'database': database},
... {'src': 'What are all distinct countries where singers above age 20 are from?', 'database': database},
... {'src': 'Show the name and the release year of the song by the youngest singer.', 'database': database}
... ]
>>> output = model.inference(instruction, data=data)
>>> print('\n'.join(o.text for o in output))
"select name, country, age from singer order by age desc"
"select distinct country from singer where age > 20"
"select song_name, song_release_year from singer order by age limit 1"
Video Captioning#
>>> instruction = '[VIDEO:video] what does the video describe? -> [TEXT:cap]'
>>> data = {'video': 'oss://ofasys/datasets/msrvtt_data/videos/video7021.mp4'}
>>> output = model.inference(instruction, data=data)
>>> print(output.text)
"a baseball player is hitting a ball"
Speech-to-Text Generation#
>>> instruction = '[AUDIO:wav] what is the text corresponding to the voice? -> [TEXT:text,preprocess=text_phone]'
>>> data = {'wav': 'oss://ofasys/data/librispeech/dev-clean/1272/128104/1272-128104-0001.flac'}
>>> output = model.inference(instruction, data=data)
>>> print(output.text)
"nor is mister klohs manner less interesting than his manner"
Text-to-Image Generation#
>>> instruction = 'what is the complete image? caption: [TEXT:text]"? -> [IMAGE,preprocess=image_vqgan,adaptor=image_vqgan]'
>>> data = {'text': "a city with tall buildings and a large green park."}
>>> output = model.inference(instruction, data=data)
>>> output[0].save_image('0.png')
The complete script is available at scripts/inference_multiple_task.py.