#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pb push_back
#define mk make_pair
#define S second
#define int long long
void solve(){
int n,m;
cin>>n>>m;
int a[n][m];
map<int,vector<int>> vx,vy;
map<int,int> mp, sum;
int ans=0;
for(int i=0; i<n; i++){
for(int j=0; j<m; j++){
cin>>a[i][j];
vx[a[i][j]].push_back(i);
vy[a[i][j]].push_back(j);
mp[a[i][j]]++;
}
}
// int ans=0;
for(auto x: vx){
sort(x.second.begin(),x.second.end());
int sum=0, cnt=0;
for(int j: x.second){
if(cnt>0){
ans+=j*cnt-sum;
}
sum+=j;
cnt++;
}
// cout<<x.first<<" "<<ans<<endl;
}
for(auto y: vy){
sort(y.second.begin(),y.second.end());
int sum=0, cnt=0;
for(int j: y.second){
if(cnt>0){
ans+=j*cnt-sum;
}
sum+=j;
cnt++;
}
}
cout<<ans*2<<endl;
}
int main(){
int t=1; // cin >> t; while(t--) solve();
solve();
}