GitHub链接:https://github.com/LiangYang666/prototypical-networks/tree/handbag
支持多gpu分布式训练,支持高版本pytorch1.x

EpisodicBatchSampler 抽样器

n_episodes: Number of episodes or equivalently batch size
n_way: Number of classes to sample
n_samples: Number of samples per episode (Usually n_query + n_support)
  • 部分参数
    • n_episodesbatch size
    • n_way 为每个batch中,抽取的训练类别数
    • n_samples为每一类中抽取的样本数量, 一般为查询集+支撑集数量
  • 每个batch采样生成的数据格式
    • Batch format: (c_i_1, c_j_1, ..., c_n_way_1, c_i_2, c_j_2, ... , c_n_way_2, ..., c_n_way_n_samples)
    • 不考虑图片维度,输出为一维,化为二维可看成,每一行n_way个,一共有n_samples

损失计算

  • class_prototypes,其shape(n_way_train, 512),为每一个类别其n_support个图经过网络生成的特征向量平均值
  • model(data_query), 其shape为(n_support*n_way_train, 512 )
  • logits, 其shape为(n_support*n_way_train, n_way), 为根据支撑集生成的特征向量,输出最有可能的类别
  • loss, 即交叉熵
Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐