## Loss function for an autoencoder of sparse 3D images


I have 3D structure data for molecules. I represented each atom as a point in a 100×100×100 grid, with one channel per atom type, and applied a Gaussian blur to counter the sparseness (nearly all grid cells contain zeros). I am trying to build an autoencoder that yields a meaningful "molecule structure to vector" encoder.
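The voxelization and blur described above could be done as a preprocessing step with a fixed kernel. Below is a minimal sketch using NumPy and SciPy's `gaussian_filter`. It assumes the atom coordinates are already scaled to grid units and that atom types are integer indices 0 to 7. The name `voxelize` and the default `sigma=1.0` are hypothetical choices, not part of the original pipeline.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

GRID = 100     # voxels per side
N_TYPES = 8    # one channel per atom type


def voxelize(coords, types, grid=GRID, n_types=N_TYPES, sigma=1.0):
    """Hypothetical helper: place atoms on a grid, then blur each channel.

    coords: (N, 3) float array of atom positions, assumed already in grid units.
    types:  (N,) int array of atom-type indices in [0, n_types).
    """
    vol = np.zeros((grid, grid, grid, n_types), dtype=np.float32)
    idx = np.clip(np.rint(coords).astype(int), 0, grid - 1)  # nearest voxel
    vol[idx[:, 0], idx[:, 1], idx[:, 2], types] = 1.0          # two atoms in one voxel still give 1
    # sigma 0 on the last axis: blur the three spatial axes only, never mix atom types.
    # mode="constant" pads with zeros instead of reflecting mass back at the borders.
    return gaussian_filter(vol, sigma=(sigma, sigma, sigma, 0), mode="constant")
```

Because SciPy normalizes the Gaussian weights, the blurred volume still sums to roughly the number of atoms, as long as no atom sits within a few `sigma` of the border.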

My current approach uses convolutional and max-pooling layers, then flattens the result and passes it through a few dense layers to get a vector representation. The decoder reshapes that vector and upsamples it back to the full grid, where a sigmoid output predicts the probability that each voxel (grid cell) contains an atom of each type (see code below).

I am worried that the model will not learn with binary cross-entropy because the data is so sparse: predicting zeros almost everywhere already gives a low loss. I want a loss function that punishes predictions that are nowhere near an atom more than predictions that are off by only a few grid cells.
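One way to get this behaviour is to compare prediction and target after smoothing both with a fixed Gaussian. A predicted atom one voxel away from a true atom then still overlaps it and costs little, while a prediction many voxels away does not overlap and costs the full amount. The sketch below is a swapped-in technique (a Gaussian-smoothed MSE), not the only option. The helper names and the `sigma`/`radius` defaults are hypothetical, and the kernel is fixed rather than trained.

```python
import numpy as np
import tensorflow as tf


def gaussian_kernel_1d(sigma, radius):
    """Normalized 1D Gaussian weights on [-radius, radius]."""
    x = np.arange(-radius, radius + 1, dtype=np.float32)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return (k / k.sum()).astype(np.float32)


def blur3d(vol, kernel_1d):
    """Blur a (batch, D, H, W, C) tensor with a fixed, separable Gaussian.

    Each channel (atom type) is blurred independently: the channel axis is
    folded into the batch axis and single-channel 3D convolutions are applied.
    """
    shape = tf.shape(vol)
    b, d, h, w, c = shape[0], shape[1], shape[2], shape[3], shape[4]
    x = tf.transpose(vol, [0, 4, 1, 2, 3])            # (B, C, D, H, W)
    x = tf.reshape(x, [b * c, d, h, w, 1])
    n = kernel_1d.shape[0]
    k = tf.constant(kernel_1d)
    # A 3D Gaussian is separable: one 1D pass along each spatial axis.
    for kshape in ([n, 1, 1, 1, 1], [1, n, 1, 1, 1], [1, 1, n, 1, 1]):
        x = tf.nn.conv3d(x, tf.reshape(k, kshape),
                         strides=[1, 1, 1, 1, 1], padding="SAME")
    x = tf.reshape(x, [b, c, d, h, w])
    return tf.transpose(x, [0, 2, 3, 4, 1])           # back to (B, D, H, W, C)


def make_blurred_mse(sigma=2.0, radius=6):
    """Hypothetical loss factory: MSE between Gaussian-smoothed volumes."""
    kernel = gaussian_kernel_1d(sigma, radius)

    def blurred_mse(y_true, y_pred):
        # Blurring is linear, so blur(y_true) - blur(y_pred) == blur(y_true - y_pred).
        diff = tf.cast(y_true, tf.float32) - tf.cast(y_pred, tf.float32)
        return tf.reduce_mean(tf.square(blur3d(diff, kernel)), axis=[1, 2, 3, 4])

    return blurred_mse
```

It can be passed to Keras like any callable loss, for example `autoencoder.compile(optimizer="adam", loss=make_blurred_mse(sigma=2.0))`. The penalty grows with the distance between predicted and true atoms only up to a few `sigma` and then levels off; summing the loss over several `sigma` values is one way to extend that range.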

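```python
import tensorflow as tf
from tensorflow.keras.layers import (Conv3D, Dense, Dropout, Flatten, Input, Lambda,
                                     MaxPooling3D, Reshape, UpSampling3D)
from tensorflow.keras.models import Model
# DepthwiseConv3D is not a core tf.keras layer; import it from the package that provides it.

latent_dim = 512
input_mol = Input(shape=(100, 100, 100, 8))  # 8 channels for the different atom types


def sampling(args):
    # Reparameterization trick: z = mean + exp(log_sigma) * eps, with eps ~ N(0, 1)
    z_mean, z_log_sigma = args
    eps = tf.random.normal(shape=tf.shape(z_mean))
    return z_mean + tf.exp(z_log_sigma) * eps


# Intended as a Gaussian blur. The kernel is trainable and randomly initialised, so it only
# acts as a Gaussian blur if its weights are set to a Gaussian and frozen (trainable=False).
x = DepthwiseConv3D(kernel_size=(9, 9, 9), depth_multiplier=1, groups=8,
                    padding="same", use_bias=False)(input_mol)
x = Conv3D(64, (3, 3, 3), activation='relu')(x)   # 100 -> 98
x = MaxPooling3D((5, 5, 5))(x)                     # 98 -> 19
x = Conv3D(32, (3, 3, 3), activation='relu')(x)   # 19 -> 17
x = MaxPooling3D((2, 2, 2))(x)                     # 17 -> 8
x = Conv3D(16, (3, 3, 3), activation='relu')(x)   # 8 -> 6
x = MaxPooling3D((2, 2, 2))(x)                     # 6 -> 3
x = Flatten()(x)                                   # 3*3*3*16 = 432 values
x = Dense(1000, activation='relu')(x)
x = Dropout(rate=0.4)(x)
encoded = Dense(latent_dim, activation='relu')(x)

# VAE branch. Note: z is not used below; the decoder reads `encoded`, so this model is a
# plain autoencoder. For a VAE, reshape z instead and add a KL-divergence term to the loss.
z_mean = Dense(latent_dim)(encoded)
z_log_sigma = Dense(latent_dim)(encoded)
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_sigma])

x = Reshape((8, 8, 8, 1))(encoded)                 # requires latent_dim == 8*8*8 == 512

x = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(x)
x = UpSampling3D((2, 2, 2))(x)                                    # 8 -> 16
x = Conv3D(32, (3, 3, 3), activation='relu', padding='valid')(x)  # 16 -> 14
x = UpSampling3D((2, 2, 2))(x)                                    # 14 -> 28
x = Conv3D(32, (3, 3, 3), activation='relu', padding='valid')(x)  # 28 -> 26
x = UpSampling3D((2, 2, 2))(x)                                    # 26 -> 52
x = Conv3D(8, (3, 3, 3), activation='relu', padding='valid')(x)   # 52 -> 50
x = UpSampling3D((2, 2, 2))(x)                                    # 50 -> 100
# Per-voxel, per-atom-type probability of an atom being present
decoded = Conv3D(8, (10, 10, 10), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_mol, decoded)
```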
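Because the encoder mixes "valid" convolutions with pooling, and the decoder has to land exactly on 100 voxels per side again, the spatial sizes are worth checking by hand. The pure-Python sketch below reproduces that arithmetic. The `trace` helper and the layer lists are hypothetical, transcribed from the model above, and assume Keras defaults (stride 1 for convolutions, pooling stride equal to pool size with "valid" padding).

```python
def trace(n, layers):
    """Return the spatial size (per side) after each layer in `layers`."""
    sizes = [n]
    for kind, k in layers:
        if kind == "conv_valid":      # stride-1 'valid' convolution
            n = n - k + 1
        elif kind == "conv_same":     # 'same' padding keeps the size
            pass
        elif kind == "pool":          # MaxPooling3D, stride = pool size, 'valid'
            n = (n - k) // k + 1
        elif kind == "up":            # UpSampling3D repeats each voxel k times
            n = n * k
        sizes.append(n)
    return sizes


ENCODER = [("conv_same", 9), ("conv_valid", 3), ("pool", 5),
           ("conv_valid", 3), ("pool", 2), ("conv_valid", 3), ("pool", 2)]
DECODER = [("conv_same", 3), ("up", 2), ("conv_valid", 3), ("up", 2),
           ("conv_valid", 3), ("up", 2), ("conv_valid", 3), ("up", 2),
           ("conv_same", 10)]

enc = trace(100, ENCODER)
dec = trace(8, DECODER)
print(enc)  # [100, 100, 98, 19, 17, 8, 6, 3]
print(dec)  # [8, 8, 16, 14, 28, 26, 52, 50, 100, 100]
```

The trace shows the encoder reduces the grid to 3×3×3×16 = 432 values before `Dense(1000)`, and that the decoder's valid convolutions and ×2 upsamplings return exactly to 100 per side, so `decoded` has the same shape as the input.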