当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

AIGC
后台-插件-广告管理-内容页头部广告(手机)

作者 | 朱祥茹、段忠杰、汪诚愚、黄俊

来源 | 阿里开发者公众号

导读

用户生成内容(User Generated Content,UGC)是互联网上多模态内容的重要组成部分,UGC数据级的不断增长促进了各大多模态内容平台的繁荣。在海量多模态数据和深度学习大模型的加持下,AI生成内容(AI Generated Content,AIGC)呈现出爆发性增长趋势。其中,文图生成(Text-to-image Generation)任务是流行的跨模态生成任务,旨在生成与给定文本对应的图像。典型的文图模型例如OpenAI开发的DALL-E和DALL-E2。近期,业界也训练出了更大、更新的文图生成模型,例如Google提出的Parti和Imagen,基于扩散模型的Stable Diffusion等。

然而,上述模型一般不能用于处理中文的需求,而且上述模型的参数量庞大,很难被开源社区的广大用户直接用来Fine-tune和推理。此外,文图生成模型的训练过程对于知识的理解比较缺乏,容易生成反常识内容。本次,EasyNLP开源框架在先前推出的基于Transformer的文图生成模型(看这里[1])基础上,进一步推出了融合丰富知识图谱知识的文图生成模型ARTIST,能在知识图谱的指引上,生成更加符合常识的图片。我们在中文文图生成评测基准MUGE上评测了ARTIST的生成效果,其生成效果名列榜单第一。我们也向开源社区免费开放了知识增强的中文文图生成模型的Checkpoint,以及相应Fine-tune和推理接口。用户可以在我们开放的Checkpoint基础上进行少量领域相关的微调,在不消耗大量计算资源的情况下,就能一键进行各种艺术创作。

EasyNLP[2]是阿⾥云机器学习PAI 团队基于 PyTorch 开发的易⽤且丰富的中⽂NLP算法框架,⽀持常⽤的中⽂预训练模型和⼤模型落地技术,并且提供了从训练到部署的⼀站式 NLP 开发体验。EasyNLP 提供了简洁的接⼝供⽤户开发 NLP 模型,包括NLP应⽤ AppZoo 和预训练 ModelZoo,同时提供技术帮助⽤户⾼效的落地超⼤预训练模型到业务。由于跨模态理解需求的不断增加,EasyNLP也⽀持各种跨模态模型,特别是中⽂领域的跨模态模型,推向开源社区,希望能够服务更多的 NLP 和多模态算法开发者和研 究者,也希望和社区⼀起推动 NLP /多模态技术的发展和模型落地。

本⽂简要介绍ARTIST的技术解读,以及如何在EasyNLP框架中使⽤ARTIST模型。

ARTIST模型详解

ARTIST模型的构建基于Transformer模型 ,将文图生成任务分为两个阶段进行,第一阶段是通过VQGAN模型对图像进行矢量量化,即对于输入的图像,通过编码器将图像编码为定长的离散序列,解码阶段是以离散序列作为输入,输出重构图。第二阶段是将文本序列和编码后的图像序列作为输入,利用GPT模型学习以文本序列为条件的图像序列生成。为了增强模型先验,我们设计了一个Word Lattice Fusion Layer,将知识图谱中的的实体知识引入模型,辅助图像中对应实体的生成,从而使得生成的图像的实体信息更加精准。下图是ARTIST模型的系统框图,以下从文图生成总体流程和知识注入两方面介绍本方案。

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

第一阶段:基于VQGAN的图像矢量量化

在VQGAN的训练阶段,我们利用数据中的图片,以图像重构为任务目标,训练一个图像词典的codebook,其中,这一codebook保存每个image token的向量表示。实际操作中,对于一张图片,通过CNN Encoder编码后得到中间特征向量,再对特征向量中的每个编码位置寻找codebook中距离最近的表示,从而将图像转换成由codebook中的imaga token表示的离散序列。第二阶段中,GPT模型会以文本为条件生成图像序列,该序列输入到VQGAN Decoder,从而重构出一张图像。


第二阶段:以文本序列为输入利用GPT生成图像序列

为了将知识图谱中的知识融入到文图生成模型中,我们首先通过TransE对中文知识图谱CN-DBpedia进行了训练,得到了知识图谱中的实体表示。在GPT模型训练阶段,对于文本输入,首先识别出所有的实体,然后将已经训练好的实体表示和token embedding进行结合,增强实体表示。但是,由于每个文本token可能属于多个实体,如果将多个实体的表示全都引入模型,可能会造成知识噪声问题。所以我们设计了实体表示交互模块,通过计算每个实体表示和token embedding的交互,为所有实体表示加权,有选择地进行知识注入。特别地,我们计算每个实体表征对对于当前token embedding的重要性,通过内积进行衡量,然后将实体表示的加权平均值注入到当前token embedding中,计算过程如下:

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

得到知识注入的token embedding后,我们通过构建具有layer norm的self-attention网络,构建基于Transformer的GPT模型,过程如下:

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

在GPT模型的训练阶段,将文本序列和图像序列拼接作为输入,假设文本序列为w, 生成图像的imaga token表示的离散序列概率如下所示:

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

最后,模型通过最大化图像部分的负对数似然来训练,得到模型参数的值。

ARTIST模型效果

标准数据集评测结果

我们在多个中文数据集上评估了ARTIST模型的效果,这些数据集的统计数据如下所示:

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

在Baseline方面,我们考虑两种情况:zero-shot learning和标准fine-tuning。我们将40亿参数的中文CogView模型作为zero-shot learner,我们也考虑两个模型规模和ARTIST模型规模相当的模型,分别为开源的DALL-E模型和OFA模型。实验数据如下所示:

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

从上可以看出,我们的模型在参数量很小的情况(202M)下也能获得较好的图文生成效果。为了衡量注入知识的有效性,我们进一步进行了相关评测,将知识模块移除,实验效果如下:

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

上述结果可以清楚地看出知识注入的作用。

案例分析

为了更加直接地比较不同场景下,ARTIST和baseline模型生成图像质量对比,我们展示了电商商品场景和自然风光场景下各个模型生成图像的效果,如下图:

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界 当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

上图可以看出ARTIST生成图像质量的优越性。我们进一步比较我们先前公开的模型(看这里[3])和具有丰富知识的ARTIST模型的效果。在第一个示例“手工古风复原款发钗汉服配饰宫廷发簪珍珠头饰发冠”中,原始生成的结果主要突出了珍珠发冠这个物体。在ARTIST模型中,“古风”等词的知识注入过程使得模型生成结果会更偏向于古代中国的珍珠发簪。

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

第二个示例为“一颗绿色的花椰菜在生长”。由于模型在训练时对“花椰菜”物体样式掌握不够,当不包含知识注入模块时,模型根据“绿色”和“菜”的提示生成了有大片绿叶的单株植物。在ARTIST模型中,生成的物体更接近于形如花椰菜的椭圆形的植物。

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

ARTIST模型在MUGE榜单的评测结果

MUGE(Multimodal Understanding and Generation Evaluation[4])是业界首个大规模中文多模态评测基准,其中包括基于文本的图像生成任务。我们使用本次推出的ARTIST模型在中文MUGE评测榜单上验证了前述文图生成模型的效果。从下图可见,ARTIST模型生成的图像在FID指标(Frechet Inception Distance,值越低表示生成图像质量越好)上超越了榜单上的其他结果。

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

ARTIST模型的实现

在EasyNLP框架中,我们在模型层构建了ARTIST模型的Backbone,其主要是GPT,输入分别是token id和包含的实体的embedding,输出是图片各个patch对应的离散序列。其核⼼代码如下所示:

# in easynlp/appzoo/text2image_generation/model.py# initself.transformer = GPT_knowl(self.config)# forwardx = inputs['image']c = inputs['text']words_emb = inputs['words_emb']x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)# one step to produce the logits_, z_indices = self.encode_to_z(x) c_indices = ccz_indices = torch.cat((c_indices, a_indices), dim=1)# make the predictionlogits, _ = self.transformer(cz_indices[:, :-1], words_emb, flag=True)# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)logits = logits[:, c_indices.shape[1]-1:]

在数据预处理过程中,我们需要获得当前样本的输入文本和实体embedding,从而计算得到words_emb:

# in easynlp/appzoo/text2image_generation/data.py# preprocess word_matrixwords_mat = np.zeros([self.entity_num, self.text_len], dtype=np.int)if len(lex_id) > 0:    ents = lex_id.split(' ')[:self.entity_num]    pos_s = [int(x) for x in pos_s.split(' ')]    pos_e = [int(x) for x in pos_e.split(' ')]    ent_pos_s = pos_s[token_len:token_len+self.entity_num]    ent_pos_e = pos_e[token_len:token_len+self.entity_num]    for i, ent in enumerate(ents):        words_mat[i, ent_pos_s[i]:ent_pos_e[i]+1] = entencoding['words_mat'] = words_mat# in batch_fnwords_mat = torch.LongTensor([example['words_mat'] for example in batch])words_emb = self.embed(words_mat)

ARTIST模型使⽤教程

以下我们简要介绍如何在EasyNLP框架使⽤ARTIST模型。

安装EasyNLP

⽤户可以直接参考GitHub[2]上的说明安装EasyNLP算法框架。

数据准备

  1. 准备自己的数据,将image编码为base64形式:ARTIST在具体领域应用需要finetune, 需要用户准备下游任务的训练与验证数据,为tsv文件。这⼀⽂件包含以制表符\t分隔的三列(idx, text, imgbase64),第一列是文本编号,第二列是文本,第三列是对应图片的base64编码。样例如下:
64b4109e34a0c3e7310588c00fc9e157  韩国可爱日系袜子女中筒袜春秋薄款纯棉学院风街头卡通兔子长袜潮  iVBORw0KGgoAAAAN...MAAAAASUVORK5CYII=

下列⽂件已经完成预处理,可⽤于训练和测试:

https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_train.tsvhttps://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_val.tsvhttps://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_test.tsv

将输入数据与lattice、entity位置信息拼接到一起:输出格式为以制表符\t分隔的几列(idx, text, lex_ids, pos_s, pos_e, seq_len, [Optional] imgbase64)

# 下载entity to entity_id映射表wget wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/entity2id.txtpython examples/text2image_generation/preprocess_data_knowl.py \    --input_file ./tmp/T2I_train.tsv \    --entity_map_file ./tmp/entity2id.txt \    --output_file ./tmp/T2I_knowl_train.tsvpython examples/text2image_generation/preprocess_data_knowl.py \    --input_file ./tmp/T2I_val.tsv \    --entity_map_file ./tmp/entity2id.txt \    --output_file ./tmp/T2I_knowl_val.tsvpython examples/text2image_generation/preprocess_data_knowl.py \    --input_file ./tmp/T2I_test.tsv \    --entity_map_file ./tmp/entity2id.txt \    --output_file ./tmp/T2I_knowl_test.tsv

ARTIST文图生成微调和预测示例

在文图生成任务中,我们对ARTIST进行微调,之后用于微调后对模型进行预测。相关示例代码如下:

# 下载entity_id与entity_vector的映射表wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/entity2vec.pt# finetunepython -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main_knowl.py \    --mode=train \    --worker_gpu=1 \    --tables=./tmp/T2I_knowl_train.tsv,./tmp/T2I_knowl_val.tsv \    --input_schema=idx:str:1,text:str:1,lex_id:str:1,pos_s:str:1,pos_e:str:1,token_len:str:1,imgbase64:str:1,  \    --first_sequence=text \    --second_sequence=imgbase64 \    --checkpoint_dir=./tmp/artist_model_finetune \    --learning_rate=4e-5 \    --epoch_num=2 \    --random_seed=42 \    --logging_steps=100 \    --save_checkpoint_steps=200 \    --sequence_length=288 \    --micro_batch_size=8 \    --app_name=text2image_generation \    --user_defined_parameters='        pretrain_model_name_or_path=alibaba-pai/pai-artist-knowl-base-zh        entity_emb_path=./tmp/entity2vec.pt        size=256        text_len=32        img_len=256        img_vocab_size=16384      ' # predictpython -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main_knowl.py \    --mode=predict \    --worker_gpu=1 \    --tables=./tmp/T2I_knowl_test.tsv \    --input_schema=idx:str:1,text:str:1,lex_id:str:1,pos_s:str:1,pos_e:str:1,token_len:str:1, \    --first_sequence=text \    --outputs=./tmp/T2I_outputs_knowl.tsv \    --output_schema=idx,text,gen_imgbase64 \    --checkpoint_dir=./tmp/artist_model_finetune \    --sequence_length=288 \    --micro_batch_size=8 \    --app_name=text2image_generation \    --user_defined_parameters='        entity_emb_path=./tmp/entity2vec.pt        size=256        text_len=32        img_len=256        img_vocab_size=16384        max_generated_num=4      '

点击查看原文,获取更多福利!

https://developer.aliyun.com/article/1072675?groupCode=alitech?utm_content=g_1000363837

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

后台-插件-广告管理-内容页尾部广告(手机)
标签:

评论留言

我要留言

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。