#include <bits/stdc++.h>
using namespace std;
//#include "grader.cpp"
#include "rect.h"
const int N=2505;
int tree[N+1];
void add(int i,int val){
for(++i;i<=N;i+=i&-i)tree[i]+=val;
}
int sum(int i){
int res=0;
for(++i;i>=1;i-=i&-i)res+=tree[i];
return res;
}
vector<int>v[N][N];
vector<pair<int,int>>cols[N][N],rows[N][N];
long long count_rectangles(vector<vector<int>>a){
int n=a.size(),m=a[0].size();
for (int i=0;i<n;i++){
vector<int>s;
for(int j=0;j<m;j++){
while(!s.empty()&&a[i][s.back()]<a[i][j])s.pop_back();
if(!s.empty()&&j-s.back()>=2)v[s.back()][j].push_back(i);
s.push_back(j);
}
s.clear();
for(int j=m-1;j>=0;j--){
while(!s.empty()&&a[i][s.back()]<a[i][j])s.pop_back();
if(!s.empty()&&s.back()-j>=2&&a[i][s.back()]!=a[i][j])v[j][s.back()].push_back(i);
s.push_back(j);
}
}
for(int i=0;i<m;i++){
for(int j=i+2;j<m;j++){
for (int l=0;l<v[i][j].size();){
int r=l+1;
while(r<v[i][j].size()&&v[i][j][r]==v[i][j][r-1]+1)r++;
for(int k=l;k<r;k++){
if(v[i][j][k])cols[v[i][j][k]-1][i].push_back({j,v[i][j][r-1]+1});
}
l=r;
}
}
}
for(auto&x:v)for(auto&y:x)y.clear();
for(int j=0;j<m;j++){
vector<int>s;
for(int i=0;i<n;i++){
while(!s.empty()&&a[s.back()][j]<a[i][j])s.pop_back();
if(!s.empty()&&i-s.back()>=2)v[s.back()][i].push_back(j);
s.push_back(i);
}
s.clear();
for(int i=n-1;i>=0;i--){
while(!s.empty()&&a[s.back()][j]<a[i][j])s.pop_back();
if(!s.empty()&&s.back()-i>=2&&a[s.back()][j]!=a[i][j])v[i][s.back()].push_back(j);
s.push_back(i);
}
}
for(int i=0;i<n;i++){
for(int j=i+2;j<n;j++){
for(int l=0;l<v[i][j].size();){
int r=l+1;
while(r<v[i][j].size()&&v[i][j][r]==v[i][j][r-1]+1)r++;
for(int k=l;k<r;k++){
if(v[i][j][k])rows[i][v[i][j][k]-1].push_back({j,v[i][j][r-1]+1});
}
l=r;
}
}
}
long long ans=0;
for(int i=0;i<n;i++){
for(int j=0;j<m;j++){
sort(rows[i][j].begin(),rows[i][j].end(),[](auto const&a,auto const&b){
return a.second>b.second;
});
sort(cols[i][j].begin(),cols[i][j].end(),[](auto const&a,auto const&b){
return a.first>b.first;
});
auto it=rows[i][j].begin();
for(auto jt=cols[i][j].begin();jt!=cols[i][j].end();jt++){
while(it!=rows[i][j].end()&&it->second>=jt->first){
add(it->first,1);
it++;
}
ans+=sum(jt->second);
}
for(auto jt=rows[i][j].begin();jt!=it;jt++)add(jt->first,-1);
}
}
return ans;
}