# SRPose / utils / transform.py
# Repository page header (not code): FrickYinn's picture
# Upload 53 files
# e170a8e verified
import torch
import numpy as np
########################################################################################################
# Based on the paper : On the Continuity of Rotation Representations in Neural Networks
# code from https://github.com/papagina/RotationContinuity/blob/master/Inverse_Kinematics/code/tools.py
# batch*n
def normalize_vector(v):
    """
    Scale every row of ``v`` to unit Euclidean length.

    Rows whose norm is below 1e-8 are divided by 1e-8 instead of by their own norm, so an all-zero
    row maps to an all-zero row rather than to NaN.
    :param v: [B, N]
    :return: [B, N] tensor holding the row-normalised vectors
    """
    # torch.norm rejects integer tensors; the previous sqrt-based formulation implicitly promoted them
    # to the default float dtype, so do that explicitly to stay compatible.
    if not torch.is_floating_point(v):
        v = v.to(torch.get_default_dtype())
    # torch.norm back-propagates a zero gradient at the origin, whereas sqrt(sum(v^2)) produces
    # 0/0 = NaN gradients for an all-zero row even though the forward value is clamped below.
    v_mag = torch.norm(v, p=2, dim=1, keepdim=True)  # batch*1
    # The floor is kept as a float32 tensor created directly on v's device, so the dtype promotion
    # of the result is the same as before.
    floor = torch.tensor([1e-8], dtype=torch.float32, device=v.device)
    v_mag = torch.max(v_mag, floor)
    # batch*1 broadcasts against batch*n, no explicit expand needed.
    return v / v_mag
# u, v batch*n
def cross_product(u, v):
    """
    Row-wise 3D cross product ``u x v``.

    Only columns 0, 1 and 2 of each input are read.
    :param u: [B, 3]
    :param v: [B, 3]
    :return: [B, 3]
    """
    ux, uy, uz = u[:, 0], u[:, 1], u[:, 2]
    vx, vy, vz = v[:, 0], v[:, 1], v[:, 2]
    components = (
        uy * vz - uz * vy,
        uz * vx - ux * vz,
        ux * vy - uy * vx,
    )
    # Each component is a length-B vector; stacking on dim 1 yields batch*3.
    return torch.stack(components, dim=1)
def rotation_matrix_from_ortho6d(poses):
    """
    Build rotation matrices from the continuous 6D rotation representation introduced in
    "On the Continuity of Rotation Representations in Neural Networks"
    https://arxiv.org/pdf/1812.07035.pdf

    The first three values give the (unnormalised) first column; the last three are only used to
    fix the plane in which the second column lies.
    :param poses: [B, 6]
    :return: R: [B, 3, 3], whose columns are the three axes computed below
    """
    # First column: the normalised leading 3-vector.
    first_axis = normalize_vector(poses[:, 0:3])  # batch*3
    # Third column: unit vector perpendicular to the first axis and the raw second 3-vector.
    third_axis = normalize_vector(cross_product(first_axis, poses[:, 3:6]))  # batch*3
    # Second column: cross product of the third and first axes.
    second_axis = cross_product(third_axis, first_axis)  # batch*3
    # Stacking on the last dim places the axes as matrix columns.
    return torch.stack((first_axis, second_axis, third_axis), dim=2)  # batch*3*3
#quaternion batch*4
def rotation_matrix_from_quaternion(quaternion):
    """
    Convert quaternions to rotation matrices.

    The input is normalised first, so it does not have to be unit length. Components are read in
    (w, x, y, z) order.
    :param quaternion: [B, 4]
    :return: R: [B, 3, 3]
    """
    unit = normalize_vector(quaternion)
    w, x, y, z = unit[:, 0], unit[:, 1], unit[:, 2], unit[:, 3]

    # Pairwise products used by the unit-quaternion rotation matrix formula.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    top = torch.stack((1 - 2 * yy - 2 * zz, 2 * xy - 2 * zw, 2 * xz + 2 * yw), dim=1)  # batch*3
    middle = torch.stack((2 * xy + 2 * zw, 1 - 2 * xx - 2 * zz, 2 * yz - 2 * xw), dim=1)  # batch*3
    bottom = torch.stack((2 * xz - 2 * yw, 2 * yz + 2 * xw, 1 - 2 * xx - 2 * yy), dim=1)  # batch*3

    # Stack the three rows along dim 1 to obtain batch*3*3.
    return torch.stack((top, middle, bottom), dim=1)
def correct_intrinsic_scale(K, scale_x, scale_y):
    '''Return the intrinsic matrix (3x3) that matches an image resized by scale_x and scale_y.

    K is left-multiplied by the affine pixel transform
        x' = scale_x * x + (scale_x / 2 - 0.5);  y' = scale_y * y + (scale_y / 2 - 0.5)
    i.e. a pure scaling plus the half-pixel offset term used in the source below.
    Source: https://dsp.stackexchange.com/questions/6055/how-does-resizing-an-image-affect-the-intrinsic-camera-matrix
    '''
    shift_x = scale_x / 2 - 0.5
    shift_y = scale_y / 2 - 0.5
    resize = np.array(
        [
            [scale_x, 0.0, shift_x],
            [0.0, scale_y, shift_y],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float64,
    )
    return resize @ K