在pytorch中构建一个momentum queue一般使用torch.nn.Module.register_buffer函数,但是mindspore中没有类似的注册方法,所以只能新建一个Parameter来保存。对比如下:

# pytorch
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
# mindspore
randn = np.random.randn
image_queue = Parameter(normalize(Tensor(randn(self.K, config.hash_bit), mstype.float32)), 'image_queue', requires_grad=False)

在更新部分,mindspore框架的GRAPH模式不能直接对参数的slice赋值,因此只能使用矩阵乘法实现:

batch_size = image_feats.shape[0]

ptr = int(self.ptr_queue)
assert self.queue_size % batch_size == 0  # for simplicity

self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
ptr = (ptr + batch_size) % self.queue_size # move pointer

self.ptr_queue[0] = ptr  
# mindspore
# 在模型定义部分写
self.slide = Tensor(np.arange(0, self.K, 1), mstype.int32)

# 在更新方法里写
keys = ops.stop_gradient(k)
batch_size = keys.shape[0]

assert self.K % batch_size == 0
slide = self.slide
mask = logical_and((slide >= self.queue_ptr), (slide < (self.queue_ptr + batch_size)))
slide = cast(nonzero(mask * (slide + 1)).squeeze(), mstype.int32)
scatter_update(self.queue, slide, keys)
assign(self.queue_ptr, (self.queue_ptr + batch_size) % self.K)

实际上就是需要先构建一个选中要替换部分的mask,然后使用mask构建一个slide,最终根据这个slide指定的位置更新queue。

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐